Skip to content

Commit 6b55eca

Browse files
committed
Update plots
1 parent d009416 commit 6b55eca

File tree

1 file changed

+5
-5
lines changed

1 file changed

+5
-5
lines changed

07_1_nonlinear_regression.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77
# 0) Prepare data
88
np.random.seed(326636)
9-
X_numpy = np.random.rand(500, 1) * 10
9+
X_numpy = np.random.rand(1000, 1) * 10
1010
X_numpy = np.sort(X_numpy, axis=0)
1111

1212
# Logistic growth equation parameters
@@ -61,10 +61,10 @@ def forward(self, x):
6161
learning_rate = 0.01
6262

6363
criterion = nn.MSELoss()
64-
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)
64+
optimizer = torch.optim.RAdam(model.parameters(), lr=learning_rate)
6565

6666
# 3) Training loop
67-
num_epochs = 20000
67+
num_epochs = 10000
6868
loss_values = []
6969

7070
for epoch in range(num_epochs):
@@ -100,8 +100,8 @@ def forward(self, x):
100100
X_test_tensor = torch.from_numpy(X_test)
101101
predicted_test = model(X_test_tensor).detach().numpy()
102102

103-
ax2.plot(X_numpy, y_numpy, 'ro', label='Original data')
104-
ax2.plot(X_test, predicted_test, 'bo', label='Model predictions')
103+
ax2.plot(X_numpy, y_numpy, 'ro', label='Original data', markersize=2)
104+
ax2.plot(X_test, predicted_test, 'b-', label='Model predictions', linewidth=3)
105105
ax2.set_xlabel('X')
106106
ax2.set_ylabel('y')
107107
ax2.set_title('Original Data vs. Model Predictions')

0 commit comments

Comments
 (0)