Skip to content

Commit 2bdfedf

Browse files
committed
improved error handling and fixed tests
1 parent 2d26cc9 commit 2bdfedf

File tree

2 files changed

+49
-41
lines changed

2 files changed

+49
-41
lines changed

src/org/soulspace/qclojure/ml/application/training.clj

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
(ns org.soulspace.qclojure.ml.application.training
22
"Training algorithms and cost functions for quantum machine learning"
33
(:require [clojure.string :as str]
4+
[clojure.spec.alpha :as s]
45
[fastmath.core :as m]
56
[org.soulspace.qclojure.domain.state :as state]
67
[org.soulspace.qclojure.domain.circuit :as circuit]
@@ -11,6 +12,12 @@
1112
[org.soulspace.qclojure.application.algorithm.variational-algorithm :as va]
1213
[org.soulspace.qclojure.ml.application.encoding :as encoding]))
1314

15+
;; Specs for training data validation
16+
(s/def ::feature-vector (s/coll-of number? :kind vector?))
17+
(s/def ::features (s/coll-of ::feature-vector :kind vector?))
18+
(s/def ::labels (s/coll-of int? :kind vector?))
19+
(s/def ::training-data (s/keys :req-un [::features ::labels]))
20+
1421
;; QML-specific cost functions and loss function library
1522

1623
(defn cross-entropy-loss
@@ -113,6 +120,15 @@
113120
Cost value (real number)"
114121
[parameters ansatz-fn features labels backend & {:keys [options] :or {options {}}}]
115122
(try
123+
;; Input validation
124+
(when (or (nil? parameters)
125+
(not (vector? parameters))
126+
(empty? parameters)
127+
(empty? features)
128+
(empty? labels)
129+
(not= (count features) (count labels)))
130+
(throw (ex-info "Invalid inputs" {:parameters parameters :features features :labels labels})))
131+
116132
(let [num-samples (count features)
117133
loss-function (:loss-function options :cross-entropy)
118134
regularization (:regularization options :none)

test/org/soulspace/qclojure/ml/application/training_test.clj

Lines changed: 33 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -35,16 +35,15 @@
3535
(is (< result 100.0) "Cost should be reasonable for test data")))
3636

3737
(testing "Error handling"
38-
(let [invalid-params []
38+
(let [invalid-params [] ; Empty params should trigger error handling in training
3939
ansatz-fn (ansatz/hardware-efficient-ansatz 2 1)
4040
valid-features [[0.5 0.3]]
4141
valid-labels [0]
42-
backend (sim/create-simulator)
43-
44-
result (training/classification-cost invalid-params ansatz-fn valid-features valid-labels backend)]
42+
backend (sim/create-simulator)]
4543

46-
;; The current implementation returns 1000.0 on error
47-
(is (= 1000.0 result) "Should return high cost on error"))))
44+
;; Test with empty parameters - should catch error and return 1000.0
45+
(is (= 1000.0 (training/classification-cost invalid-params ansatz-fn valid-features valid-labels backend))
46+
"Should return high cost on error"))))
4847

4948
(deftest test-parameter-shift-gradient-validation
5049
(testing "Parameter shift gradient validation using optimization namespace"
@@ -99,58 +98,51 @@
9998
;; Property-based tests
10099
(deftest test-cost-function-properties
101100
(testing "Cost function properties"
102-
(let [test-cases (take 5 (gen/sample
101+
;; Generate test cases with properly sized parameters for 2-qubit, 1-layer ansatz
102+
(let [correct-param-gen (gen/vector (gen/double* {:min -3.14 :max 3.14 :NaN? false :infinite? false}) 6 6)
103+
test-cases (take 5 (gen/sample
103104
(gen/hash-map
104-
:parameters parameter-vector-gen
105+
:parameters correct-param-gen
105106
:features feature-matrix-gen
106107
:labels binary-labels-gen)))]
107108

108109
(doseq [test-case test-cases]
109110
(when (= (count (:features test-case)) (count (:labels test-case)))
110111
(let [ansatz-fn (ansatz/hardware-efficient-ansatz 2 1)
111112
backend (sim/create-simulator)
112-
cost (try
113-
(training/classification-cost
114-
(:parameters test-case)
115-
ansatz-fn
116-
(:features test-case)
117-
(:labels test-case)
118-
backend)
119-
(catch Exception _ 1000.0))]
113+
cost (training/classification-cost
114+
(:parameters test-case)
115+
ansatz-fn
116+
(:features test-case)
117+
(:labels test-case)
118+
backend)]
120119

121120
(is (number? cost) "Cost should be numeric")
122-
(is (>= cost 0.0) "Cost should be non-negative")))))))
121+
(is (>= cost 0.0) "Cost should be non-negative")
122+
(is (< cost 1000.0) "Cost should be reasonable for valid inputs")))))))
123123

124124
;; Error handling tests
125125
(deftest test-error-handling
126126
(testing "Various error conditions"
127127
(let [valid-params [0.1 0.2 0.3 0.4 0.5 0.6]
128128
ansatz-fn (ansatz/hardware-efficient-ansatz 2 1)
129-
backend (sim/create-simulator)
130-
131-
; Test mismatched features and labels
132-
mismatched-result (try
133-
(training/classification-cost
134-
valid-params ansatz-fn
135-
[[0.5 0.3]] [0 1] backend) ; 1 feature, 2 labels
136-
(catch Exception _ 1000.0))
137-
138-
; Test empty data
139-
empty-result (try
140-
(training/classification-cost
141-
valid-params ansatz-fn [] [] backend)
142-
(catch Exception _ 1000.0))
143-
144-
; Test nil inputs
145-
nil-result (try
146-
(training/classification-cost
147-
nil ansatz-fn [[0.5]] [0] backend)
148-
(catch Exception _ 1000.0))]
129+
backend (sim/create-simulator)]
130+
131+
;; Test mismatched features and labels
132+
(is (= 1000.0 (training/classification-cost
133+
valid-params ansatz-fn
134+
[[0.5 0.3]] [0 1] backend)) ; 1 feature, 2 labels
135+
"Should handle mismatched data")
136+
137+
;; Test empty data
138+
(is (= 1000.0 (training/classification-cost
139+
valid-params ansatz-fn [] [] backend))
140+
"Should handle empty data")
149141

150-
; Current implementation returns 1000.0 on all errors
151-
(is (= 1000.0 mismatched-result) "Should handle mismatched data")
152-
(is (= 1000.0 empty-result) "Should handle empty data")
153-
(is (= 1000.0 nil-result) "Should handle nil inputs"))))
142+
;; Test nil inputs - these should be caught by the try-catch in classification-cost
143+
(is (= 1000.0 (training/classification-cost
144+
nil ansatz-fn [[0.5]] [0] backend))
145+
"Should handle nil inputs"))))
154146

155147
;; Rich comment block for REPL testing
156148
(comment

0 commit comments

Comments
 (0)