"""Tabular Q-learning agent (notebook)
This notebooks can be run directly from VSCode, to generate a
traditional Jupyter Notebook to open in your browser
 you can run the VSCode command `Export Currenty Python File As Jupyter Notebook`.
"""
import sys
import os
import logging
from typing import cast
import gymnasium as gym
import numpy as np
import matplotlib.pyplot as plt 
from cyberbattle.agents.baseline.learner import TrainedLearner
import cyberbattle.agents.baseline.plotting as p
import cyberbattle.agents.baseline.agent_wrapper as w
import cyberbattle.agents.baseline.agent_tabularqlearning as a
from cyberbattle.agents.baseline.agent_wrapper import Verbosity
import cyberbattle.agents.baseline.learner as learner
import cyberbattle._env.cyberbattle_env as cyberbattle_env
from cyberbattle._env.cyberbattle_env import AttackerGoal
logging.basicConfig(stream=sys.stdout, level=logging.ERROR, format="%(levelname)s: %(message)s")
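# Create a 'chain' CyberBattle environment of size 10; the attacker's goal is to own 100% of the nodes.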
cyberbattlechain_10 = gym.make("CyberBattleChain-v0", size=10, attacker_goal=AttackerGoal(own_atleast_percent=1.0)).unwrapped
assert isinstance(cyberbattlechain_10, cyberbattle_env.CyberBattleEnv)
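# Environment bounds used to dimension the agent's observation and action spaces.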
ep = w.EnvironmentBounds.of_identifiers(maximum_node_count=12, maximum_total_credentials=12, identifiers=cyberbattlechain_10.identifiers)
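# Training and evaluation parameters.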
iteration_count = 9000
training_episode_count = 5
eval_episode_count = 5
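# Discount factors (gamma) to sweep over; a single value is used here.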
gamma_sweep = [
    0.015,  
]
plots_dir = 'output/plots'
os.makedirs(plots_dir, exist_ok=True)
def qlearning_run(gamma, gym_env):
    """Execute one run of the q-learning algorithm for the
    specified gamma value"""
    return learner.epsilon_greedy_search(
        gym_env,
        ep,
        a.QTabularLearner(ep, gamma=gamma, learning_rate=0.90, exploit_percentile=100),
        episode_count=training_episode_count,
        iteration_count=iteration_count,
        epsilon=0.90,  # initial exploration rate
        render=False,
        epsilon_multdecay=0.75,  # multiplicative decay applied to epsilon after each episode
        epsilon_minimum=0.01,  # floor below which epsilon stops decaying
        verbosity=Verbosity.Quiet,
        title="Q-learning",
    )
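# Run Q-learning once per gamma value; with a single-value sweep the first run is kept as the best run.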
qlearning_results = [qlearning_run(gamma, cyberbattlechain_10) for gamma in gamma_sweep]
qlearning_bestrun_10 = qlearning_results[0]
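# Plot the Q_source and Q_attack learning loss across all training episodes for each run.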
p.new_plot_loss()
for results in qlearning_results:
    p.plot_all_episodes_loss(cast(a.QTabularLearner, results["learner"]).loss_qsource.all_episodes, "Q_source", results["title"])
    p.plot_all_episodes_loss(cast(a.QTabularLearner, results["learner"]).loss_qattack.all_episodes, "Q_attack", results["title"])
plt.legend(loc="upper right")
plt.show()
p.plot_episodes_length(qlearning_results)
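# Evaluate the trained agent by exploiting the learned Q-matrices (gamma=0 and learning_rate=0 disable further learning).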
nolearning_results = learner.epsilon_greedy_search(
    cyberbattlechain_10,
    ep,
    learner=a.QTabularLearner(ep, trained=qlearning_bestrun_10["learner"], gamma=0.0, learning_rate=0.0, exploit_percentile=100),
    episode_count=eval_episode_count,
    iteration_count=iteration_count,
    epsilon=0.30,  # mostly exploit the learned Q-matrix, with some residual exploration
    render=False,
    title="Exploiting Q-matrix",
    verbosity=Verbosity.Quiet,
)
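# Baseline: purely random search (epsilon=1.0), so the Q-matrices are never exploited.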
randomlearning_results = learner.epsilon_greedy_search(
    cyberbattlechain_10,
    ep,
    learner=a.QTabularLearner(ep, trained=qlearning_bestrun_10["learner"], gamma=0.0, learning_rate=0.0, exploit_percentile=100),
    episode_count=eval_episode_count,
    iteration_count=iteration_count,
    epsilon=1.0,  # always explore: actions are chosen at random
    render=False,
    verbosity=Verbosity.Quiet,
    title="Random search",
)
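# Compare averaged cumulative rewards across the Q-learning, random-search, and exploitation runs.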
all_runs = [*qlearning_results, randomlearning_results, nolearning_results]
Q_source_10 = cast(a.QTabularLearner, qlearning_bestrun_10["learner"]).qsource
Q_attack_10 = cast(a.QTabularLearner, qlearning_bestrun_10["learner"]).qattack
p.plot_averaged_cummulative_rewards(
    all_runs=all_runs,
    title=f"Benchmark -- max_nodes={ep.maximum_node_count}, episodes={eval_episode_count},\n"
    f"dimension={Q_source_10.state_space.flat_size()}x{Q_source_10.action_space.flat_size()}, "
    f"{Q_attack_10.state_space.flat_size()}x{Q_attack_10.action_space.flat_size()}\n"
    f"Q1={[f.name() for f in Q_source_10.state_space.feature_selection]} "
    f"-> {[f.name() for f in Q_source_10.action_space.feature_selection]})\n"
    f"Q2={[f.name() for f in Q_attack_10.state_space.feature_selection]} -> 'action'",
    save_at=os.path.join(plots_dir, "benchmark-tabularq-cumrewards.png")
)
p.plot_all_episodes(qlearning_results[0])
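# Inspect the non-zero entries of the learned 'source' Q-matrix in human-readable form.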
i = np.where(Q_source_10.qm)
q = Q_source_10.qm[i]
list(zip(
    np.array([Q_source_10.state_space.pretty_print(si) for si in i[0]]),
    np.array([Q_source_10.action_space.pretty_print(ai) for ai in i[1]]),
    q,
))
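# Same inspection for the 'attack' Q-matrix.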
i2 = np.where(Q_attack_10.qm)
q2 = Q_attack_10.qm[i2]
list(zip(
    [Q_attack_10.state_space.pretty_print(si) for si in i2[0]],
    [Q_attack_10.action_space.pretty_print(ai) for ai in i2[1]],
    q2,
))
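# Transfer learning: train on a smaller chain of size 4, then evaluate the agent on the size-10 chain and vice versa.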
cyberbattlechain_4 = gym.make("CyberBattleChain-v0", size=4, attacker_goal=AttackerGoal(own_atleast_percent=1.0)).unwrapped
assert isinstance(cyberbattlechain_4, cyberbattle_env.CyberBattleEnv)
qlearning_bestrun_4 = qlearning_run(0.015, gym_env=cyberbattlechain_4)
def stop_learning(trained_learner):
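    """Wrap a trained learner's Q-matrices in a new learner with learning disabled (gamma=0, learning_rate=0) for evaluation-only runs."""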
    return TrainedLearner(
        learner=a.QTabularLearner(ep, gamma=0.0, learning_rate=0.0, exploit_percentile=0, trained=trained_learner["learner"]),
        title=trained_learner["title"],
        trained_on=trained_learner["trained_on"],
        all_episodes_rewards=trained_learner["all_episodes_rewards"],
        all_episodes_availability=trained_learner["all_episodes_availability"],
    )
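# Evaluate the agent trained on the size-4 chain on the size-10 environment.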
learner.transfer_learning_evaluation(
    environment_properties=ep,
    trained_learner=stop_learning(qlearning_bestrun_4),
    eval_env=cyberbattlechain_10,
    eval_epsilon=0.5,  
    eval_episode_count=eval_episode_count,
    iteration_count=iteration_count,
)
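# Conversely, evaluate the agent trained on the size-10 chain on the size-4 environment.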
learner.transfer_learning_evaluation(
    environment_properties=ep,
    trained_learner=stop_learning(qlearning_bestrun_10),
    eval_env=cyberbattlechain_4,
    eval_epsilon=0.5,
    eval_episode_count=eval_episode_count,
    iteration_count=iteration_count,
)