From a68d17029af707ba137b370f6080778a3e56fe70 Mon Sep 17 00:00:00 2001 From: Mischa Mikami Date: Sat, 7 Dec 2024 03:13:32 -0800 Subject: [PATCH] formatting issues --- color_analysis/color_analysis_model.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/color_analysis/color_analysis_model.py b/color_analysis/color_analysis_model.py index 0a938fc..a42c616 100644 --- a/color_analysis/color_analysis_model.py +++ b/color_analysis/color_analysis_model.py @@ -96,7 +96,6 @@ def train_model(model, train_loader, criterion, optimizer, num_epochs=10, device torch.save(model.state_dict(), save_path) -# Testing Function def test_model(model, test_loader, device='cpu'): """Evaluate the model on the test set and return accuracy""" # Model is set to evaluation mode @@ -120,6 +119,7 @@ def test_model(model, test_loader, device='cpu'): test_accuracy = correct / total * 100 return test_loss, test_accuracy + def predict_image(model, image_path, class_labels, transform, device='cpu'): """Predicts the class of a single image""" model.eval() @@ -130,6 +130,7 @@ def predict_image(model, image_path, class_labels, transform, device='cpu'): _, pred = torch.max(outputs, 1) return class_labels[pred.item()] + def load_pretrained_model(model, save_path='color_analysis/trained_model.pth', device='cpu'): """Loads a previously trained model or starts from scratch if no model is found""" if os.path.exists(save_path): @@ -137,6 +138,7 @@ def load_pretrained_model(model, save_path='color_analysis/trained_model.pth', d model.eval() return model + def save_season_palette(predicted_season): palette = extract_colors(image=f'color_analysis/nonspecific-season-palettes/{predicted_season}-palette.jpg', palette_size=48) w, h = 48, 48 @@ -149,6 +151,7 @@ def save_season_palette(predicted_season): img.save(f"combined_demo/output-imgs/your-palette.jpg") + def main(): # Get the image path for the prediction parser = argparse.ArgumentParser(description="ResNet-18 Color Analysis")