- 
                Notifications
    You must be signed in to change notification settings 
- Fork 219
Add Losses #129
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
          
     Merged
      
      
    
  
     Merged
                    Add Losses #129
Changes from 17 commits
      Commits
    
    
            Show all changes
          
          
            27 commits
          
        
        Select commit
          Hold shift + click to select a range
      
      c57a2e7
              
                Merge pull request #3 from tensorflow/master
              
              
                JimClarke5 9cc2675
              
                Initial checkin to rebase to Initialziers to pick up changes to ndarr…
              
              
                JimClarke5 2508f5e
              
                Initial Checkin for losses
              
              
                JimClarke5 17e96b5
              
                Fix reshape in sparseCategoricalCrossentropy()
              
              
                JimClarke5 ee1c48a
              
                Apply various fixes to JavaDoc
              
              
                JimClarke5 287c96e
              
                Change Tuple to LossTuple
              
              
                JimClarke5 642069c
              
                Repair JavaDOx
              
              
                JimClarke5 249b651
              
                Fixed AllAxis to hanlde dynamic shape when static shape rank is unknown.
              
              
                JimClarke5 794cfdc
              
                change method name allAxis to allAxes
              
              
                JimClarke5 fb26c59
              
                change private method binaryCrossentropy to binaryCrossentropyHelper
              
              
                JimClarke5 928ef06
              
                Fixed squeezeOrExpandDimensions to make sure the updated labels, pred…
              
              
                JimClarke5 2bc54dd
              
                Fix JavaDoc,
              
              
                JimClarke5 951443b
              
                Fix unused imports and add @SuppressWarnings("unchecked") for casts.
              
              
                JimClarke5 ebac9e8
              
                Add copyright
              
              
                JimClarke5 d8f3254
              
                Add CastHelper and used that for all casts
              
              
                JimClarke5 02573b5
              
                Fix JavaDoc, change snake case to camel case.
              
              
                JimClarke5 0bf49fe
              
                Change class LossesImpl to LossesHelper
              
              
                JimClarke5 0eae9ee
              
                Remove commented out JavaDoc
              
              
                JimClarke5 b211937
              
                Changed method name from smoothLabelsBinaryX to smoothBinaryLabels,
              
              
                JimClarke5 3e0669e
              
                Fixed JavaDoc for labelSmoothing
              
              
                JimClarke5 914f16f
              
                Fixed JavaDoc to change label_smoothing to labelSmoothing.
              
              
                JimClarke5 7eefbb7
              
                Fix formatting
              
              
                JimClarke5 b87ad16
              
                replace label_smoothing with labelSmoothing.
              
              
                JimClarke5 c43cd21
              
                Add copyright to test cases
              
              
                JimClarke5 4d9fd24
              
                Fix copyright to attribute TensorFlow Authors.
              
              
                JimClarke5 d56d8d9
              
                Fix typo on broadcast in JavaDoc
              
              
                JimClarke5 744e324
              
                Fix typo on broadcast in JavaDoc
              
              
                JimClarke5 File filter
Filter by extension
Conversations
          Failed to load comments.   
        
        
          
      Loading
        
  Jump to
        
          Jump to file
        
      
      
          Failed to load files.   
        
        
          
      Loading
        
  Diff view
Diff view
There are no files selected for viewing
        
          
          
            230 changes: 230 additions & 0 deletions
          
          230 
        
  tensorflow-framework/src/main/java/org/tensorflow/framework/losses/BinaryCrossentropy.java
  
  
      
      
   
        
      
      
    
  
    
      This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
      Learn more about bidirectional Unicode characters
    
  
  
    
              | Original file line number | Diff line number | Diff line change | 
|---|---|---|
| @@ -0,0 +1,230 @@ | ||
| /* | ||
| * Copyright (c) 2020, Oracle and/or its affiliates. All rights reserved. | ||
| * | ||
| * Licensed under the Apache License, Version 2.0 (the "License"); | ||
| * you may not use this file except in compliance with the License. | ||
| * You may obtain a copy of the License at | ||
| * | ||
| * http://www.apache.org/licenses/LICENSE-2.0 | ||
| * | ||
| * Unless required by applicable law or agreed to in writing, software | ||
| * distributed under the License is distributed on an "AS IS" BASIS, | ||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| * See the License for the specific language governing permissions and | ||
| * limitations under the License. | ||
| */ | ||
| package org.tensorflow.framework.losses; | ||
|  | ||
| import org.tensorflow.Operand; | ||
| import org.tensorflow.framework.losses.impl.LossesHelper; | ||
| import org.tensorflow.op.Ops; | ||
| import org.tensorflow.types.family.TNumber; | ||
|  | ||
| import static org.tensorflow.framework.utils.CastHelper.cast; | ||
|  | ||
| /** | ||
| * Computes the cross-entropy loss between true labels and predicted labels. | ||
| * | ||
| * <p>Use this cross-entropy loss when there are only two label classes (assumed to be 0 and 1). For | ||
| * each example, there should be a single floating-point value per prediction. | ||
| * | ||
| * <p>Standalone usage: | ||
| * | ||
| * <pre> | ||
| * Operand<TFloat32> labels = | ||
| * tf.constant(new float[][] {{0.f, 1.f}, {0.f, 0.f}}); | ||
| * Operand<TFloat32> predictions = | ||
| * tf.constant(new float[][] {{0.6f, 0.4f}, {0.4f, 0.6f}}); | ||
| * BinaryCrossentropy bce = new BinaryCrossentropy(tf); | ||
| * Operand<TFloat32> result = bce.call(labels, predictions); | ||
| * // produces 0.815 | ||
| * </pre> | ||
| * | ||
| * <p>Calling with sample weight: | ||
| * | ||
| * <pre> | ||
| * Operand<TFloat32> sampleWeight = tf.constant(new float[] {1.f, 0.f}); | ||
| * Operand<TFloat32> result = bce.call(labels, predictions, sampleWeight); | ||
| * // produces 0.458f | ||
| * </pre> | ||
| * | ||
| * <p>Using <code>SUM</code> reduction type: | ||
| * | ||
| * <pre> | ||
| * BinaryCrossentropy bce = new BinaryCrossentropy(tf, Reduction.SUM); | ||
| * Operand<TFloat32> result = bce.call(labels, predictions); | ||
| * // produces 1.630f | ||
| * </pre> | ||
| * | ||
| * <p>Using <code>NONE</code> reduction type: | ||
| * | ||
| * <pre> | ||
| * BinaryCrossentropy bce = new BinaryCrossentropy(tf, Reduction.NONE); | ||
| * Operand<TFloat32> result = bce.call(labels, predictions); | ||
| * // produces [0.916f, 0.714f] | ||
| * </pre> | ||
| */ | ||
| public class BinaryCrossentropy extends Loss { | ||
| public static final boolean FROM_LOGITS_DEFAULT = false; | ||
| public static final float LABEL_SMOOTHING_DEFAULT = 0.0f; | ||
|  | ||
| private final boolean fromLogits; | ||
| private final float labelSmoothing; | ||
|  | ||
| /** | ||
| * Creates a Binary Crossentropy Loss using {@link Class#getSimpleName()} as the loss name, {@link | ||
| * #FROM_LOGITS_DEFAULT} for fromLogits, {@link #LABEL_SMOOTHING_DEFAULT} for labelSmoothing and a | ||
| * Loss Reduction of {@link Loss#REDUCTION_DEFAULT} | ||
| * | ||
| * @param tf the TensorFlow Ops | ||
| */ | ||
| public BinaryCrossentropy(Ops tf) { | ||
| this(tf, null, FROM_LOGITS_DEFAULT, LABEL_SMOOTHING_DEFAULT, REDUCTION_DEFAULT); | ||
| } | ||
|  | ||
| /** | ||
| * Creates a Binary Crossentropy loss using {@link Class#getSimpleName()} as the loss name, {@link | ||
| * #FROM_LOGITS_DEFAULT} for fromLogits, and {@link #LABEL_SMOOTHING_DEFAULT} for labelSmoothing | ||
| * | ||
| * @param tf the TensorFlow Ops | ||
| * @param reduction Type of Reduction to apply to the loss. | ||
| */ | ||
| public BinaryCrossentropy(Ops tf, Reduction reduction) { | ||
| this(tf, null, FROM_LOGITS_DEFAULT, LABEL_SMOOTHING_DEFAULT, reduction); | ||
| } | ||
|  | ||
| /** | ||
| * Creates a Binary Crossentropy loss using using {@link Class#getSimpleName()} as the loss name, | ||
| * labelSmoothing of {@link #LABEL_SMOOTHING_DEFAULT}, a reduction of {@link | ||
| * Loss#REDUCTION_DEFAULT}, | ||
| * | ||
| * @param tf the TensorFlow Ops | ||
| * @param fromLogits Whether to interpret predictions as a tensor of logit values | ||
| */ | ||
| public BinaryCrossentropy(Ops tf, boolean fromLogits) { | ||
| this(tf, null, fromLogits, LABEL_SMOOTHING_DEFAULT, REDUCTION_DEFAULT); | ||
| } | ||
|  | ||
| /** | ||
| * Creates a Binary Crossentropy loss using labelSmoothing of {@link #LABEL_SMOOTHING_DEFAULT} a | ||
| * reduction of {@link Loss#REDUCTION_DEFAULT}. | ||
| * | ||
| * @param tf the TensorFlow Ops | ||
| * @param name the name of the loss | ||
| * @param fromLogits Whether to interpret predictions as a tensor of logit values | ||
| */ | ||
| public BinaryCrossentropy(Ops tf, String name, boolean fromLogits) { | ||
| this(tf, name, fromLogits, LABEL_SMOOTHING_DEFAULT, REDUCTION_DEFAULT); | ||
| } | ||
|  | ||
| /** | ||
| * Creates a Binary Crossentropy loss using using {@link Class#getSimpleName()} as the loss name, | ||
| * and a reduction of {@link Loss#REDUCTION_DEFAULT}. | ||
| * | ||
| * @param tf the TensorFlow Ops | ||
| * @param fromLogits Whether to interpret predictions as a tensor of logit values | ||
| * @param labelSmoothing A number in the range, [0, 1]. When 0, no smoothing occurs. When > 0, | ||
| * compute the loss between the predicted labels and a smoothed version of the true labels, | ||
| * where the smoothing squeezes the labels towards 0.5. Larger values of label_smoothing | ||
| * correspond to heavier smoothing. | ||
| */ | ||
| public BinaryCrossentropy(Ops tf, boolean fromLogits, float labelSmoothing) { | ||
| this(tf, null, fromLogits, labelSmoothing, REDUCTION_DEFAULT); | ||
| } | ||
|  | ||
| /** | ||
| * Creates a Binary Crossentropy loss using a reduction of {@link Loss#REDUCTION_DEFAULT}. | ||
| * | ||
| * @param tf the TensorFlow Ops | ||
| * @param name the name of the loss | ||
| * @param fromLogits Whether to interpret predictions as a tensor of logit values | ||
| * @param labelSmoothing A number in the range, [0, 1]. When 0, no smoothing occurs. When > 0, | ||
| * compute the loss between the predicted labels and a smoothed version of the true labels, | ||
| * where the smoothing squeezes the labels towards 0.5. Larger values of label_smoothing | ||
| * correspond to heavier smoothing. | ||
| */ | ||
| public BinaryCrossentropy(Ops tf, String name, boolean fromLogits, float labelSmoothing) { | ||
| this(tf, name, fromLogits, labelSmoothing, REDUCTION_DEFAULT); | ||
| } | ||
|  | ||
| /** | ||
| * Creates a Binary Crossentropy loss | ||
| * | ||
| * @param tf the TensorFlow Ops | ||
| * @param fromLogits Whether to interpret predictions as a tensor of logit values | ||
| * @param labelSmoothing A number in the range, [0, 1]. When 0, no smoothing occurs. When > 0, | ||
| * compute the loss between the predicted labels and a smoothed version of the true labels, | ||
| * where the smoothing squeezes the labels towards 0.5. Larger values of label_smoothing | ||
| * correspond to heavier smoothing. | ||
| * @param reduction Type of Reduction to apply to the loss. | ||
| */ | ||
| public BinaryCrossentropy(Ops tf, boolean fromLogits, float labelSmoothing, Reduction reduction) { | ||
| this(tf, null, fromLogits, labelSmoothing, reduction); | ||
| } | ||
|  | ||
| /** | ||
| * Creates a Binary Crossentropy loss | ||
| * | ||
| * @param tf the TensorFlow Ops | ||
| * @param name the name of the loss | ||
| * @param fromLogits Whether to interpret predictions as a tensor of logit values | ||
| * @param labelSmoothing A number in the range, [0, 1]. When 0, no smoothing occurs. When > 0, | ||
| * compute the loss between the predicted labels and a smoothed version of the true labels, | ||
| * where the smoothing squeezes the labels towards 0.5. Larger values of label_smoothing | ||
| * correspond to heavier smoothing. | ||
| * @param reduction Type of Reduction to apply to the loss. | ||
| * @throws IllegalArgumentException if labelSmoothing is not in the inclusive range of 0. - 1. | ||
| */ | ||
| public BinaryCrossentropy( | ||
| Ops tf, String name, boolean fromLogits, float labelSmoothing, Reduction reduction) { | ||
| super(tf, name, reduction); | ||
| if(labelSmoothing < 0 || labelSmoothing > 1) | ||
| throw new IllegalArgumentException("labelSmoothing must be >= 0. and <= 1, found " + labelSmoothing); | ||
| this.fromLogits = fromLogits; | ||
| this.labelSmoothing = labelSmoothing; | ||
| } | ||
|  | ||
| /** | ||
| * Generates an Operand that calculates the loss. | ||
| * | ||
| * If run in Graph mode, the computation will throw {@link org.tensorflow.exceptions.TFInvalidArgumentException} | ||
| * if the predictions values are outside the range o [0. to 1.]. In Eager Mode, this call | ||
| * will throw {@link IllegalArgumentException}, if the predictions values are outside the range o [0. to 1.] | ||
| * | ||
| * @param labels the truth values or labels | ||
| * @param predictions the predictions, values must be in the range [0. to 1.] inclusive. | ||
| * @param sampleWeights Optional SampleWeights acts as a coefficient for the loss. If a scalar is | ||
| * provided, then the loss is simply scaled by the given value. If SampleWeights is a tensor | ||
| * of size [batch_size], then the total loss for each sample of the batch is rescaled by the | ||
| * corresponding element in the SampleWeights vector. If the shape of SampleWeights is | ||
| * [batch_size, d0, .. dN-1] (or can be broadcasted to this shape), then each loss element of | ||
| * predictions is scaled by the corresponding value of SampleWeights. (Note on dN-1: all loss | ||
| * functions reduce by 1 dimension, usually axis=-1.) | ||
| * @param <T> The data type of the predictions, sampleWeights and loss. | ||
| * @param <U> The data type of the labels. | ||
| * @return the loss | ||
| * @throws IllegalArgumentException if the predictions are outside the range [0.-1.]. | ||
| */ | ||
| @Override | ||
| public <T extends TNumber, U extends TNumber> Operand<T> call( | ||
| Operand<U> labels, Operand<T> predictions, Operand<T> sampleWeights) { | ||
| Operand<T> lPredictions; | ||
| if (!fromLogits) { | ||
| // add predictions range check for 0 - 1 | ||
| lPredictions = | ||
| LossesHelper.rangeCheck( | ||
| getTF(), | ||
| "predictions range check [0-1]", | ||
| predictions, | ||
| cast(getTF(), getTF().constant(0), predictions.asOutput().dataType()), | ||
| cast(getTF(), getTF().constant(1), predictions.asOutput().dataType())); | ||
|  | ||
| } else { | ||
| lPredictions = predictions; | ||
| } | ||
|  | ||
| Operand<T> losses = | ||
| Losses.binaryCrossentropy(getTF(), labels, lPredictions, fromLogits, labelSmoothing); | ||
| return LossesHelper.computeWeightedLoss(getTF(), losses, getReduction(), sampleWeights); | ||
| } | ||
| } | ||
      
      Oops, something went wrong.
        
    
  
  Add this suggestion to a batch that can be applied as a single commit.
  This suggestion is invalid because no changes were made to the code.
  Suggestions cannot be applied while the pull request is closed.
  Suggestions cannot be applied while viewing a subset of changes.
  Only one suggestion per line can be applied in a batch.
  Add this suggestion to a batch that can be applied as a single commit.
  Applying suggestions on deleted lines is not supported.
  You must change the existing code in this line in order to create a valid suggestion.
  Outdated suggestions cannot be applied.
  This suggestion has been applied or marked resolved.
  Suggestions cannot be applied from pending reviews.
  Suggestions cannot be applied on multi-line comments.
  Suggestions cannot be applied while the pull request is queued to merge.
  Suggestion cannot be applied right now. Please check back later.
  
    
  
    
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
label_smoothing->labelSmoothing, here and elsewhere in this file.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
OK