import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt

# 0) Prepare data
np.random.seed(326636)
X_numpy = np.random.rand(500, 1) * 10
X_numpy = np.sort(X_numpy, axis=0)

# Logistic growth equation parameters
L = 10   # the curve's maximum value (carrying capacity)
k = 1    # the logistic growth rate, i.e. steepness of the curve
x0 = 5   # the x-value of the sigmoid's midpoint

# Generate y values using the logistic growth equation, then add Gaussian noise
y_numpy = L / (1 + np.exp(-k * (X_numpy - x0)))
noise = np.random.normal(0, 0.25, y_numpy.shape)
y_numpy += noise
y_numpy = np.abs(y_numpy)  # keep the noisy targets non-negative

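# Shape of the data above: y = L / (1 + exp(-k * (x - x0))) rises from near 0
# to the plateau L = 10 and crosses L/2 at x0 = 5; np.abs() is just a
# safeguard for the few near-zero targets that noise could push negative.
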
# Optional: uncomment to plot the raw data
# plt.scatter(X_numpy, y_numpy, color='red', label='Original data')
# plt.xlabel('X')
# plt.ylabel('y')
# plt.title('Original Data')
# plt.legend()
# plt.show()

# Cast to float32 tensors; reshape y into a column vector so it matches
# the model's (n_samples, 1) output
X = torch.from_numpy(X_numpy.astype(np.float32))
y = torch.from_numpy(y_numpy.astype(np.float32))
y = y.view(y.shape[0], 1)

n_samples, n_features = X.shape

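# Optional sanity check (an added assertion, not in the original script):
# both tensors should be 500 x 1 column vectors at this point.
assert X.shape == (500, 1) and y.shape == (500, 1)
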
# 1) Model
# A plain linear model f = wx + b cannot fit the S-shaped data, so we use a
# small MLP (Linear -> ReLU -> Linear) instead
input_size = n_features
output_size = 1

# Define a simple neural network with one hidden layer
class NonlinearModel(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        self.hidden = nn.Linear(input_size, hidden_size)
        self.relu = nn.ReLU()
        self.output = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        x = self.hidden(x)
        x = self.relu(x)
        x = self.output(x)
        return x

hidden_size = 100  # adjust the hidden layer size as needed
model = NonlinearModel(input_size, hidden_size, output_size)

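# Optional check (added for illustration): with hidden_size=100 the model has
# (1*100 + 100) + (100*1 + 1) = 301 trainable parameters.
print(sum(p.numel() for p in model.parameters()))  # -> 301
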
# 2) Loss and optimizer
learning_rate = 0.01

criterion = nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)

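# Plain SGD can be slow to converge here, which is one reason for the large
# epoch count below; Adam is a common drop-in alternative if desired:
# optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
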
# 3) Training loop (full-batch: every epoch uses all 500 samples at once)
num_epochs = 10000
loss_values = []

for epoch in range(num_epochs):
    # Forward pass and loss
    y_predicted = model(X)
    loss = criterion(y_predicted, y)

    # Backward pass and update
    loss.backward()
    optimizer.step()

    # Zero gradients so they don't accumulate into the next iteration
    optimizer.zero_grad()

    # Store loss value for plotting
    loss_values.append(loss.item())

    if (epoch + 1) % 500 == 0:
        print(f'epoch: {epoch+1}, loss = {loss.item():.4f}')

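# With noise of std 0.25, the best achievable training MSE is roughly
# 0.25**2 = 0.0625, so the printed loss should level off near that floor
# once the network has captured the underlying curve.
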
# Create a figure with two subplots
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 6))

# Plot loss vs. epoch starting from epoch 500 on the first subplot
# (the first few hundred losses are large and would dominate the y-axis)
ax1.plot(range(500, num_epochs), loss_values[500:], label='Loss')
ax1.set_xlabel('Epoch')
ax1.set_ylabel('Loss')
ax1.set_title('Loss vs. Epoch (Starting from Epoch 500)')
ax1.legend()
# Plot original data and model predictions on the second subplot
X_test = np.linspace(0.01, 10, 200).reshape(-1, 1).astype(np.float32)
X_test_tensor = torch.from_numpy(X_test)
with torch.no_grad():  # inference only, so skip building the autograd graph
    predicted_test = model(X_test_tensor).numpy()

ax2.plot(X_numpy, y_numpy, 'ro', label='Original data')
ax2.plot(X_test, predicted_test, 'bo', label='Model predictions')
ax2.set_xlabel('X')
ax2.set_ylabel('y')
ax2.set_title('Original Data vs. Model Predictions')
ax2.legend()

# Show the figure
plt.tight_layout()
plt.show()