Skip to content

Commit

Permalink
formatting issues
Browse files Browse the repository at this point in the history
  • Loading branch information
mmikami123 committed Dec 7, 2024
1 parent 921cb63 commit a68d170
Showing 1 changed file with 4 additions and 1 deletion.
5 changes: 4 additions & 1 deletion color_analysis/color_analysis_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,6 @@ def train_model(model, train_loader, criterion, optimizer, num_epochs=10, device
torch.save(model.state_dict(), save_path)


# Testing Function
def test_model(model, test_loader, device='cpu'):
"""Evaluate the model on the test set and return accuracy"""
# Model is set to evaluation mode
Expand All @@ -120,6 +119,7 @@ def test_model(model, test_loader, device='cpu'):
test_accuracy = correct / total * 100
return test_loss, test_accuracy


def predict_image(model, image_path, class_labels, transform, device='cpu'):
"""Predicts the class of a single image"""
model.eval()
Expand All @@ -130,13 +130,15 @@ def predict_image(model, image_path, class_labels, transform, device='cpu'):
_, pred = torch.max(outputs, 1)
return class_labels[pred.item()]


def load_pretrained_model(model, save_path='color_analysis/trained_model.pth', device='cpu'):
"""Loads a previously trained model or starts from scratch if no model is found"""
if os.path.exists(save_path):
model.load_state_dict(torch.load(save_path, map_location=device, weights_only=True))
model.eval()
return model


def save_season_palette(predicted_season):
palette = extract_colors(image=f'color_analysis/nonspecific-season-palettes/{predicted_season}-palette.jpg', palette_size=48)
w, h = 48, 48
Expand All @@ -149,6 +151,7 @@ def save_season_palette(predicted_season):

img.save(f"combined_demo/output-imgs/your-palette.jpg")


def main():
# Get the image path for the prediction
parser = argparse.ArgumentParser(description="ResNet-18 Color Analysis")
Expand Down

0 comments on commit a68d170

Please sign in to comment.