|
| 1 | +// |
| 2 | +// LogitsWarperTests.swift |
| 3 | +// |
| 4 | +// Created by Jan Krukowski on 09/12/2023. |
| 5 | +// |
| 6 | + |
| 7 | +import XCTest |
| 8 | +import CoreML |
| 9 | +@testable import TensorUtils |
| 10 | + |
| 11 | +final class LogitsWarperTests: XCTestCase { |
| 12 | + private let accuracy: Float = 0.00001 |
| 13 | + |
| 14 | + func testTemperatureLogitsWarper() { |
| 15 | + let result1 = TemperatureLogitsWarper(temperature: 0.0)([]) |
| 16 | + XCTAssertTrue(result1.indexes.isEmpty) |
| 17 | + XCTAssertTrue(result1.logits.isEmpty) |
| 18 | + |
| 19 | + let result2 = TemperatureLogitsWarper(temperature: 1.0)([]) |
| 20 | + XCTAssertTrue(result2.indexes.isEmpty) |
| 21 | + XCTAssertTrue(result2.logits.isEmpty) |
| 22 | + |
| 23 | + let result3 = TemperatureLogitsWarper(temperature: 1.0)([2.0, 1.0]) |
| 24 | + XCTAssertEqual(result3.indexes, [0, 1]) |
| 25 | + XCTAssertEqual(result3.logits, [2.0, 1.0], accuracy: accuracy) |
| 26 | + |
| 27 | + let result4 = TemperatureLogitsWarper(temperature: 2.0)([2.0, 1.0]) |
| 28 | + XCTAssertEqual(result4.indexes, [0, 1]) |
| 29 | + XCTAssertEqual(result4.logits, [1.0, 0.5], accuracy: accuracy) |
| 30 | + |
| 31 | + let result5 = TemperatureLogitsWarper(temperature: 0.5)([2.0, 1.0]) |
| 32 | + XCTAssertEqual(result5.indexes, [0, 1]) |
| 33 | + XCTAssertEqual(result5.logits, [4.0, 2.0], accuracy: accuracy) |
| 34 | + } |
| 35 | + |
| 36 | + func testTopKLogitsWarper() { |
| 37 | + let result1 = TopKLogitsWarper(k: 0)([]) |
| 38 | + XCTAssertTrue(result1.indexes.isEmpty) |
| 39 | + XCTAssertTrue(result1.logits.isEmpty) |
| 40 | + |
| 41 | + let result2 = TopKLogitsWarper(k: 3)([]) |
| 42 | + XCTAssertTrue(result2.indexes.isEmpty) |
| 43 | + XCTAssertTrue(result2.logits.isEmpty) |
| 44 | + |
| 45 | + let result3 = TopKLogitsWarper(k: 3)([2.0, 1.0]) |
| 46 | + XCTAssertEqual(result3.indexes, [0, 1]) |
| 47 | + XCTAssertEqual(result3.logits, [2.0, 1.0], accuracy: accuracy) |
| 48 | + |
| 49 | + let result4 = TopKLogitsWarper(k: 3)([2.0, 1.0, 3.0]) |
| 50 | + XCTAssertEqual(result4.indexes, [2, 0, 1]) |
| 51 | + XCTAssertEqual(result4.logits, [3.0, 2.0, 1.0], accuracy: accuracy) |
| 52 | + |
| 53 | + let result5 = TopKLogitsWarper(k: 4)([2.0, 1.0, 3.0, -1.0, 123.0, 0.0]) |
| 54 | + XCTAssertEqual(result5.indexes, [4, 2, 0, 1]) |
| 55 | + XCTAssertEqual(result5.logits, [123.0, 3.0, 2.0, 1.0], accuracy: accuracy) |
| 56 | + } |
| 57 | + |
| 58 | + func testTopPLogitsWarper() { |
| 59 | + let result1 = TopPLogitsWarper(p: 0.99)([]) |
| 60 | + XCTAssertTrue(result1.indexes.isEmpty) |
| 61 | + XCTAssertTrue(result1.logits.isEmpty) |
| 62 | + |
| 63 | + let result2 = TopPLogitsWarper(p: 0.99)((0 ..< 10).map { Float($0) }) |
| 64 | + XCTAssertEqual(result2.indexes, [9, 8, 7, 6, 5]) |
| 65 | + XCTAssertEqual(result2.logits, [9.0, 8.0, 7.0, 6.0, 5.0], accuracy: accuracy) |
| 66 | + |
| 67 | + let result3 = TopPLogitsWarper(p: 0.95)((0 ..< 10).map { Float($0) }) |
| 68 | + XCTAssertEqual(result3.indexes, [9, 8, 7]) |
| 69 | + XCTAssertEqual(result3.logits, [9.0, 8.0, 7.0], accuracy: accuracy) |
| 70 | + |
| 71 | + let result4 = TopPLogitsWarper(p: 0.6321493)((0 ..< 10).map { Float($0) }) |
| 72 | + XCTAssertEqual(result4.indexes, [9, 8]) |
| 73 | + XCTAssertEqual(result4.logits, [9.0, 8.0], accuracy: accuracy) |
| 74 | + } |
| 75 | + |
| 76 | + func testLogitsProcessor() { |
| 77 | + let processor1 = LogitsProcessor(logitsWarpers: []) |
| 78 | + let result1 = processor1([]) |
| 79 | + XCTAssertTrue(result1.indexes.isEmpty) |
| 80 | + XCTAssertTrue(result1.logits.isEmpty) |
| 81 | + |
| 82 | + let processor2 = LogitsProcessor(logitsWarpers: []) |
| 83 | + let result2 = processor2([2.0, 1.0]) |
| 84 | + XCTAssertEqual(result2.indexes, [0, 1]) |
| 85 | + XCTAssertEqual(result2.logits, [2.0, 1.0], accuracy: accuracy) |
| 86 | + |
| 87 | + let processor3 = LogitsProcessor( |
| 88 | + logitsWarpers: [TopKLogitsWarper(k: 3)] |
| 89 | + ) |
| 90 | + let result3 = processor3([2.0, 1.0, 3.0, -5.0]) |
| 91 | + XCTAssertEqual(result3.indexes, [2, 0, 1]) |
| 92 | + XCTAssertEqual(result3.logits, [3.0, 2.0, 1.0], accuracy: accuracy) |
| 93 | + |
| 94 | + let processor4 = LogitsProcessor( |
| 95 | + logitsWarpers: [TopKLogitsWarper(k: 3), TopPLogitsWarper(p: 0.99)] |
| 96 | + ) |
| 97 | + let result4 = processor4([2.0, 1.0, 3.0, -5.0, -23.0, 12.5]) |
| 98 | + XCTAssertEqual(result4.indexes, [0]) |
| 99 | + XCTAssertEqual(result4.logits, [12.5], accuracy: accuracy) |
| 100 | + } |
| 101 | +} |
0 commit comments