Source code for analysis_engine.ai.plot_dnn_fit_history

"""
Plot a deep neural network's history output after training

Please check out this `blog post for more information on
how this works <https://
machinelearningmastery.com/custom-metrics-deep-learning-keras-python/>`__
"""

import datetime
import matplotlib.pyplot as plt
import analysis_engine.consts as ae_consts
import analysis_engine.charts as ae_charts
import analysis_engine.build_result as build_result
import spylunking.log.setup_logging as log_utils
from analysis_engine.send_to_slack import post_plot

# Module logger, created once via spylunking's ``build_colorized_logger``
# and shared by every function in this module.
log = log_utils.build_colorized_logger(
    name=__name__)


def plot_dnn_fit_history(
        title,
        df,
        red,
        red_color=None,
        red_label=None,
        blue=None,
        blue_color=None,
        blue_label=None,
        green=None,
        green_color=None,
        green_label=None,
        orange=None,
        orange_color=None,
        orange_label=None,
        xlabel='Training Epochs',
        ylabel='Error Values',
        linestyle='-',
        width=9.0,
        height=9.0,
        date_format='%d\n%b',
        df_filter=None,
        start_date=None,
        footnote_text=None,
        footnote_xpos=0.70,
        footnote_ypos=0.01,
        footnote_color='#888888',
        footnote_fontsize=8,
        scale_y=False,
        show_plot=True,
        dropna_for_all=False,
        verbose=False,
        send_plots_to_slack=False):
    """plot_dnn_fit_history

    Plot a DNN's fit history using the `Keras fit history object
    <https://keras.io/visualization/#training-history-visualization>`__

    Up to four columns of ``df`` are drawn as lines on a single axis.
    Each line is labeled in the legend with its matching ``*_label``
    argument, or with its column name when no label is given.

    :param title: title of the plot
    :param df: dataset which is ``pandas.DataFrame``
    :param red: string - column name to plot in ``red_color``
        (or default ``ae_consts.PLOT_COLORS['red']``) where the
        column is in the ``df`` and accessible with:``df[red]``
    :param red_color: hex color code to plot the data in the
        ``df[red]`` (default is ``ae_consts.PLOT_COLORS['red']``)
    :param red_label: optional - string for the label used
        to identify the ``red`` line in the legend
    :param blue: string - column name to plot in ``blue_color``
        (or default ``ae_consts.PLOT_COLORS['blue']``) where the
        column is in the ``df`` and accessible with:``df[blue]``
    :param blue_color: hex color code to plot the data in the
        ``df[blue]`` (default is ``ae_consts.PLOT_COLORS['blue']``)
    :param blue_label: optional - string for the label used
        to identify the ``blue`` line in the legend
    :param green: string - column name to plot in ``green_color``
        (or default ``ae_consts.PLOT_COLORS['darkgreen']``) where the
        column is in the ``df`` and accessible with:``df[green]``
    :param green_color: hex color code to plot the data in the
        ``df[green]`` (default is
        ``ae_consts.PLOT_COLORS['darkgreen']``)
    :param green_label: optional - string for the label used
        to identify the ``green`` line in the legend
    :param orange: string - column name to plot in ``orange_color``
        (or default ``ae_consts.PLOT_COLORS['orange']``) where the
        column is in the ``df`` and accessible with:``df[orange]``
    :param orange_color: hex color code to plot the data in the
        ``df[orange]`` (default is ``ae_consts.PLOT_COLORS['orange']``)
    :param orange_label: optional - string for the label used
        to identify the ``orange`` line in the legend
    :param xlabel: x-axis label
    :param ylabel: y-axis label
    :param linestyle: style of the plot line
    :param width: float - width of the image
    :param height: float - height of the image
    :param date_format: string - format for dates
        (currently not used by this function)
    :param df_filter: optional - initialized ``pandas.DataFrame`` query
        for reducing the ``df`` records before plotting. As an example
        ``df_filter=(df['close'] > 0.01)`` would find only records in
        the ``df`` with a ``close`` value greater than ``0.01``.
        When given, only the plotted columns are kept.
    :param start_date: optional - string ``datetime``
        for plotting only from a date formatted as
        ``YYYY-MM-DD HH\\:MM\\:SS``
        (currently not used by this function)
    :param footnote_text: optional - string footnote text
        (default is ``algotraders - <DATE>``)
    :param footnote_xpos: optional - float for footnote position
        on the x-axis (default is ``0.70``)
    :param footnote_ypos: optional - float for footnote position
        on the y-axis (default is ``0.01``)
    :param footnote_color: optional - string hex color code for
        the footnote text (default is ``#888888``)
    :param footnote_fontsize: optional - float footnote font size
        (default is ``8``)
    :param scale_y: optional - bool to scale the y-axis with

        .. code-block:: python

            use_ax.set_ylim(
                [0, use_ax.get_ylim()[1] * 3])

    :param show_plot: bool to show the plot
    :param dropna_for_all: optional - bool to toggle keep None's in
        the plot ``df`` (currently not used by this function)
    :param verbose: optional - bool to show logs for debugging
        a dataset
    :param send_plots_to_slack: optional - bool to send the dnn plot
        to slack
    :return: the value of ``build_result.build_result`` with
        ``status=ae_consts.SUCCESS`` and ``rec`` holding ``ax1``
        (the plot axis) and ``fig`` (the figure)
    """
    rec = {
        'ax1': None,
        'fig': None
    }
    result = build_result.build_result(
        status=ae_consts.NOT_RUN,
        err=None,
        rec=rec)

    if verbose:
        log.info('plot_dnn_fit_history - start')

    use_red = red_color or ae_consts.PLOT_COLORS['red']
    use_blue = blue_color or ae_consts.PLOT_COLORS['blue']
    use_green = green_color or ae_consts.PLOT_COLORS['darkgreen']
    use_orange = orange_color or ae_consts.PLOT_COLORS['orange']

    use_footnote = footnote_text
    if not use_footnote:
        now_str = datetime.datetime.now().strftime(
            ae_consts.COMMON_TICK_DATE_FORMAT)
        use_footnote = f'algotraders - {now_str}'

    # Resolve each requested series to its column, color and legend
    # label up front. Labels are bound to the line itself, so a label
    # can no longer leak onto a different series in the legend.
    series = [
        (red, use_red, red_label),
        (blue, use_blue, blue_label),
        (green, use_green, green_label),
        (orange, use_orange, orange_label),
    ]
    all_plots = [
        {
            'column': column,
            'color': color,
            'label': label or column
        }
        for column, color, label in series
        if column
    ]
    column_list = [node['column'] for node in all_plots]

    use_df = df
    if hasattr(df_filter, 'to_json'):
        # ``use_df[df_filter][column_list]`` raised "UserWarning: Boolean
        # Series key will be reindexed to match DataFrame index" in
        # Jupyter, so select rows and columns in a single ``.loc`` call.
        use_df = use_df.loc[df_filter, column_list]
        if verbose:
            log.info(
                f'plot_dnn_fit_history '
                f'filter df.index={len(use_df.index)} '
                f'column_list={column_list}')

    ae_charts.set_common_seaborn_fonts()

    fig, ax = plt.subplots(
        sharex=True,
        sharey=True,
        figsize=(
            width,
            height))

    num_plots = len(all_plots)
    for idx, node in enumerate(all_plots):
        column_name = node['column']
        hex_color = node['color']
        if verbose:
            log.info(
                f'plot_dnn_fit_history - '
                f'{idx + 1}/{num_plots} - '
                f'{column_name} '
                f'in '
                f'{hex_color} - '
                f'ax={ax}')
        ax.plot(
            use_df[column_name],
            label=node['label'],
            linestyle=linestyle,
            color=hex_color)
    # end of for all plots

    if scale_y:
        # Give the lines headroom: start at 0 and triple the upper limit.
        ax.set_ylim([0, ax.get_ylim()[1] * 3])

    # Keep only the first line for each label, so a column passed more
    # than once does not create duplicate legend entries.
    lines = []
    seen_labels = set()
    for line in ax.get_lines():
        line_label = line.get_label()
        if line_label not in seen_labels:
            seen_labels.add(line_label)
            lines.append(line)
    rec['ax1'] = ax

    if verbose:
        log.info(
            f'legend lines={[line.get_label() for line in lines]}')

    ax.legend(
        lines,
        [line.get_label() for line in lines],
        loc='best',
        shadow=True)
    fig.autofmt_xdate()

    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)
    ax.set_title(title)
    ae_charts.add_footnote(
        fig=fig,
        xpos=footnote_xpos,
        ypos=footnote_ypos,
        text=use_footnote,
        color=footnote_color,
        fontsize=footnote_fontsize)
    plt.tight_layout()

    if send_plots_to_slack:
        post_plot(plt, title=title)

    if show_plot:
        plt.show()
    else:
        # No data is passed here; the figure stays open and is handed
        # back to the caller through ``rec['fig']``.
        plt.plot()

    rec['fig'] = fig

    result = build_result.build_result(
        status=ae_consts.SUCCESS,
        err=None,
        rec=rec)

    return result
# end of plot_dnn_fit_history