import os

import torch

from models import SimpleMLP
from utils import *
from args import load_args_dict
from torchsummary import summary

from args import parse_train_args
from viz.eval_simple_model_viz import evaluate_model_visually

# Evaluation script: rebuild a SimpleMLP from the parsed training arguments,
# print a layer summary, load a saved checkpoint from ``args.save_path`` and
# pass the model to ``evaluate_model_visually``.

# Epoch number of the checkpoint to load (file name ``epoch_XXX.pth``,
# zero-padded to three digits).
CHECKPOINT_EPOCH = 200

initial_args = parse_train_args(create_dirs=False)
# The rest of the script reads ``args``; previously that name was never bound,
# so the script died with a NameError on the very next line.
# NOTE(review): ``load_args_dict`` is imported but unused -- the intent may have
# been to restore the run's saved arguments from disk here. Confirm.
args = initial_args

# Normalise a zero validation proportion to None.
if args.val_split_prop == 0.0:
    args.val_split_prop = None

set_seed(manualSeed=args.seed)

device = torch.device(f"cuda:{args.gpu_id}" if torch.cuda.is_available() else "cpu")
args.device = device

base_path = args.save_path

model = SimpleMLP(
    hidden=args.width,
    depth=args.depth,
    fc_bias=args.bias,
    num_classes=args.classes,
    penultimate_layer_features=args.classes,
    final_activation=args.act_fn,
    use_bn=args.use_bn,
).to(device)
# torchsummary's ``device`` argument defaults to "cuda"; pass the device type
# actually selected above instead of relying on that default.
summary(model, input_size=(3, 32, 32), batch_size=1, device=device.type)

args.load_path = base_path
checkpoint_path = os.path.join(args.load_path, f"epoch_{CHECKPOINT_EPOCH:03d}.pth")
# map_location remaps tensors onto the chosen device, so a checkpoint saved on
# a GPU can be loaded on a CPU-only machine or onto a different GPU index.
model.load_state_dict(torch.load(checkpoint_path, map_location=device))

evaluate_model_visually(args, model, base_path)
