Skip to content

Commit

Permalink
Improve random sorting algorithm (stashapp#4246)
Browse files Browse the repository at this point in the history
  • Loading branch information
DingDongSoLong4 authored and halkeye committed Sep 1, 2024
1 parent d389ede commit 45e255c
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 13 deletions.
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ GO_BUILD_FLAGS := $(GO_BUILD_FLAGS)

# set GO_BUILD_TAGS environment variable to any extra build tags required
GO_BUILD_TAGS := $(GO_BUILD_TAGS)
GO_BUILD_TAGS += sqlite_stat4
GO_BUILD_TAGS += sqlite_stat4 sqlite_math_functions

# set STASH_NOLEGACY environment variable or uncomment to disable legacy browser support
# STASH_NOLEGACY := true
Expand Down
31 changes: 19 additions & 12 deletions pkg/sqlite/sql.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,6 @@ import (
"github.com/stashapp/stash/pkg/models"
)

var randomSortFloat = rand.Float64()

func selectAll(tableName string) string {
idColumn := getColumn(tableName, "*")
return "SELECT " + idColumn + " FROM " + tableName + " "
Expand Down Expand Up @@ -66,16 +64,15 @@ func getSort(sort string, direction string, tableName string) string {
return " ORDER BY " + colName + " " + direction
case strings.HasPrefix(sort, randomSeedPrefix):
// seed as a parameter from the UI
// turn the provided seed into a float
seedStr := "0." + sort[len(randomSeedPrefix):]
seed, err := strconv.ParseFloat(seedStr, 32)
seedStr := sort[len(randomSeedPrefix):]
seed, err := strconv.ParseUint(seedStr, 10, 64)
if err != nil {
// fallback to default seed
seed = randomSortFloat
// fallback to a random seed
seed = rand.Uint64()
}
return getRandomSort(tableName, direction, seed)
case strings.Compare(sort, "random") == 0:
return getRandomSort(tableName, direction, randomSortFloat)
return getRandomSort(tableName, direction, rand.Uint64())
default:
colName := getColumn(tableName, sort)
if strings.Contains(sort, ".") {
Expand All @@ -92,11 +89,21 @@ func getSort(sort string, direction string, tableName string) string {
}
}

func getRandomSort(tableName string, direction string, seed float64) string {
// https://stackoverflow.com/a/24511461
func getRandomSort(tableName string, direction string, seed uint64) string {
// cap seed at 10^8
seed %= 1e8

colName := getColumn(tableName, "id")
randomSortString := strconv.FormatFloat(seed, 'f', 16, 32)
return " ORDER BY " + "(substr(" + colName + " * " + randomSortString + ", length(" + colName + ") + 2))" + " " + direction

// https://stackoverflow.com/questions/21949795#comment33255354_21949859
// p1 := 52959209
// p2 := 1047483763
// p3 := 2147483647
// n := <colName>
// ORDER BY ((n+seed)*(n+seed)*p1 + (n+seed)*p2) % p3
// since sqlite converts overflowing numbers to reals, a custom db function that uses uints with overflow should be faster,
// however in practice the overhead of calling a custom function vastly outweighs the benefits
return fmt.Sprintf(" ORDER BY mod((%[1]s + %[2]d) * (%[1]s + %[2]d) * 52959209 + (%[1]s + %[2]d) * 1047483763, 2147483647) %[3]s", colName, seed, direction)
}

func getCountSort(primaryTable, joinTable, primaryFK, direction string) string {
Expand Down

0 comments on commit 45e255c

Please sign in to comment.