Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
279 changes: 257 additions & 22 deletions pkg/data/struct.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,11 @@ import (
"errors"
"fmt"
"reflect"
"regexp"
"strconv"
"strings"
"sync"
"time"

"github.com/iancoleman/strcase"

Expand Down Expand Up @@ -296,7 +298,24 @@ func (u *Unmarshaler) setDefaultValue(field reflect.Value, defaultVal string) er
}
field.SetBool(b)
default:
return fmt.Errorf("unsupported default value type: %v", field.Kind())
// Handle special types that don't match basic kinds
switch field.Type() {
case reflect.TypeOf(time.Duration(0)):
duration, err := time.ParseDuration(defaultVal)
if err != nil {
return fmt.Errorf("cannot parse default duration %q: %w", defaultVal, err)
}
field.Set(reflect.ValueOf(duration))
case reflect.TypeOf(time.Time{}):
// Try multiple time formats for default values
parsedTime, err := parseTimeValue(defaultVal, "")
if err != nil {
return fmt.Errorf("cannot parse default time %q: %w", defaultVal, err)
}
field.Set(reflect.ValueOf(parsedTime))
default:
return fmt.Errorf("unsupported default value type: %v", field.Type())
}
}

return nil
Expand All @@ -312,7 +331,7 @@ func (u *Unmarshaler) unmarshalValue(ctx context.Context, val format.Value, fiel
case *numberData:
return u.unmarshalNumber(v, field)
case *stringData:
return u.unmarshalString(ctx, v, field)
return u.unmarshalString(ctx, v, field, structField)
case Array:
if field.Type().Implements(reflect.TypeOf((*format.Value)(nil)).Elem()) {
field.Set(reflect.ValueOf(v))
Expand All @@ -336,37 +355,225 @@ func (u *Unmarshaler) unmarshalValue(ctx context.Context, val format.Value, fiel
}
}

// parseInstillTag parses the instill tag and returns field name, format, pattern, and other attributes
func parseInstillTag(tag string) (fieldName, format, pattern string, attributes map[string]string) {
attributes = make(map[string]string)
if tag == "" {
return
}

// First, extract the field name (everything before the first comma)
firstCommaIdx := strings.Index(tag, ",")
if firstCommaIdx == -1 {
fieldName = tag
return
}

fieldName = tag[:firstCommaIdx]
remaining := tag[firstCommaIdx+1:]

// Parse the remaining attributes using a simple approach that handles patterns better
parts := strings.Split(remaining, ",")

for i := 0; i < len(parts); i++ {
part := strings.TrimSpace(parts[i])
if part == "" {
continue
}

switch {
case strings.HasPrefix(part, "default="):
attributes["default"] = strings.TrimPrefix(part, "default=")
case strings.HasPrefix(part, "pattern="):
// For patterns, we may need to rejoin parts if the pattern contains commas
patternValue := strings.TrimPrefix(part, "pattern=")
// Check if this looks like an incomplete regex (missing closing bracket/paren)
if strings.Contains(patternValue, "(") && !strings.Contains(patternValue, ")") && i+1 < len(parts) {
// Likely a pattern split by comma, rejoin with next parts until we find a closing paren or end
for j := i + 1; j < len(parts); j++ {
patternValue += "," + parts[j]
if strings.Contains(parts[j], ")") {
i = j // Skip the parts we've consumed
break
}
}
}
pattern = patternValue
case strings.HasPrefix(part, "format="):
format = strings.TrimPrefix(part, "format=")
case strings.Contains(part, "/") && !strings.Contains(part, "="):
// Legacy format specification without "format=" prefix
format = part
}
}

return
}

// validatePattern validates a string against a regex pattern
func validatePattern(value, pattern string) error {
if pattern == "" {
return nil
}

// Unescape the pattern (convert \\. to \.)
unescapedPattern := strings.ReplaceAll(pattern, "\\\\", "\\")

regex, err := regexp.Compile(unescapedPattern)
if err != nil {
return fmt.Errorf("invalid pattern %q: %w", pattern, err)
}

if !regex.MatchString(value) {
return fmt.Errorf("value %q does not match pattern %q", value, pattern)
}

return nil
}

// parseTimeValue parses a time string using appropriate formats based on the format hint
func parseTimeValue(timeStr, format string) (time.Time, error) {
var timeFormats []string

// If format is "date-time" or similar, prioritize RFC3339 formats
if format == "date-time" || format == "datetime" {
timeFormats = []string{
time.RFC3339,
time.RFC3339Nano,
}
} else {
// Try multiple time formats
timeFormats = []string{
time.RFC3339,
time.RFC3339Nano,
"2006-01-02T15:04:05Z07:00",
"2006-01-02 15:04:05",
"2006-01-02",
}
}

for _, timeFormat := range timeFormats {
if parsedTime, err := time.Parse(timeFormat, timeStr); err == nil {
return parsedTime, nil
}
}

return time.Time{}, fmt.Errorf("unable to parse time string with any supported format")
}

// isFileType checks if a type is a file-related format type
func isFileType(t reflect.Type) bool {
fileTypes := []reflect.Type{
reflect.TypeOf((*format.Image)(nil)).Elem(),
reflect.TypeOf((*format.Audio)(nil)).Elem(),
reflect.TypeOf((*format.Video)(nil)).Elem(),
reflect.TypeOf((*format.Document)(nil)).Elem(),
reflect.TypeOf((*format.File)(nil)).Elem(),
}

for _, fileType := range fileTypes {
if t == fileType {
return true
}
}
return false
}

// handleTimePointer handles marshaling of time pointer types
func handleTimePointer(v reflect.Value) (format.Value, bool) {
elemType := v.Type().Elem()
switch elemType {
case reflect.TypeOf(time.Time{}):
timeVal := v.Interface().(*time.Time)
return NewString(timeVal.Format(time.RFC3339)), true
case reflect.TypeOf(time.Duration(0)):
durationVal := v.Interface().(*time.Duration)
return NewString(durationVal.String()), true
}
return nil, false
}

// unmarshalString handles unmarshaling of String values.
func (u *Unmarshaler) unmarshalString(ctx context.Context, v format.String, field reflect.Value) error {
func (u *Unmarshaler) unmarshalString(ctx context.Context, v format.String, field reflect.Value, structField reflect.StructField) error {
stringValue := v.String()

// Parse instill tag for validation rules
_, _, pattern, _ := parseInstillTag(structField.Tag.Get("instill"))

// Validate against pattern if specified (applies to all string fields)
if err := validatePattern(stringValue, pattern); err != nil {
return fmt.Errorf("pattern validation failed: %w", err)
}

switch field.Kind() {
case reflect.String:
field.SetString(v.String())
field.SetString(stringValue)
case reflect.Ptr:
if field.IsNil() {
field.Set(reflect.New(field.Type().Elem()))
}
return u.unmarshalString(ctx, v, field.Elem())
return u.unmarshalString(ctx, v, field.Elem(), structField)
default:
switch field.Type() {
// Handle time.Duration
case reflect.TypeOf(time.Duration(0)):
// Parse instill tag for parsing hints
_, _, pattern, _ := parseInstillTag(structField.Tag.Get("instill"))

// If pattern suggests seconds format, parse as seconds
if pattern != "" && strings.Contains(pattern, "s$") {
// Pattern suggests seconds format like "3600s" or "3600.5s"
// Remove the 's' suffix and parse as float, then convert to duration
if strings.HasSuffix(stringValue, "s") {
secondsStr := strings.TrimSuffix(stringValue, "s")
seconds, err := strconv.ParseFloat(secondsStr, 64)
if err != nil {
return fmt.Errorf("cannot parse seconds value %q: %w", secondsStr, err)
}
duration := time.Duration(seconds * float64(time.Second))
field.Set(reflect.ValueOf(duration))
} else {
return fmt.Errorf("duration string %q does not end with 's' as required by pattern", stringValue)
}
} else {
// No pattern or different pattern, use standard Go duration parsing
duration, err := time.ParseDuration(stringValue)
if err != nil {
return fmt.Errorf("cannot unmarshal string %q into time.Duration: %w", stringValue, err)
}
field.Set(reflect.ValueOf(duration))
}
// Handle time.Time
case reflect.TypeOf(time.Time{}):
// Parse instill tag for format specification
_, format, _, _ := parseInstillTag(structField.Tag.Get("instill"))

// If the string is a URL, create a file from the URL
case reflect.TypeOf((*format.Image)(nil)).Elem(),
reflect.TypeOf((*format.Audio)(nil)).Elem(),
reflect.TypeOf((*format.Video)(nil)).Elem(),
reflect.TypeOf((*format.Document)(nil)).Elem(),
reflect.TypeOf((*format.File)(nil)).Elem():
f, err := u.createFileFromURL(ctx, field.Type(), v.String())
if err == nil {
field.Set(reflect.ValueOf(f))
return nil
parsedTime, err := parseTimeValue(stringValue, format)
if err != nil {
return fmt.Errorf("cannot unmarshal string %q into time.Time: %w", stringValue, err)
}
// If URL creation fails, return a helpful error message
return fmt.Errorf("cannot unmarshal string into %v: expected valid URL, not base64 string: %w", field.Type(), err)
case reflect.TypeOf(v), reflect.TypeOf((*format.String)(nil)).Elem():
field.Set(reflect.ValueOf(v))
case reflect.TypeOf((*format.Value)(nil)).Elem():
field.Set(reflect.ValueOf(v))
field.Set(reflect.ValueOf(parsedTime))

default:
// Try to create file from URL for media/document types
if isFileType(field.Type()) {
f, err := u.createFileFromURL(ctx, field.Type(), v.String())
if err == nil {
field.Set(reflect.ValueOf(f))
return nil
}
// If URL creation fails, return a helpful error message
return fmt.Errorf("cannot unmarshal string into %v: expected valid URL, not base64 string: %w", field.Type(), err)
}

// Handle format.Value types
if field.Type() == reflect.TypeOf(v) ||
field.Type() == reflect.TypeOf((*format.String)(nil)).Elem() ||
field.Type() == reflect.TypeOf((*format.Value)(nil)).Elem() {
field.Set(reflect.ValueOf(v))
return nil
}

return fmt.Errorf("cannot unmarshal String into %v", field.Type())
}
}
Expand Down Expand Up @@ -418,10 +625,18 @@ func (u *Unmarshaler) unmarshalNumber(v format.Number, field reflect.Value) erro
case reflect.Float32, reflect.Float64:
field.SetFloat(v.Float64())
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
// Special handling for time.Duration - should only accept string format
if field.Type() == reflect.TypeOf(time.Duration(0)) {
return fmt.Errorf("cannot unmarshal Number into time.Duration: use string format like \"60s\"")
}
field.SetInt(int64(v.Integer()))
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
field.SetUint(uint64(v.Integer()))
case reflect.Ptr:
// Special handling for *time.Duration - should only accept string format
if field.Type().Elem() == reflect.TypeOf(time.Duration(0)) {
return fmt.Errorf("cannot unmarshal Number into *time.Duration: use string format like \"60s\"")
}
if field.IsNil() {
field.Set(reflect.New(field.Type().Elem()))
}
Expand Down Expand Up @@ -733,6 +948,13 @@ func (m *Marshaler) marshalValue(v reflect.Value) (format.Value, error) {
}
}

// Handle special pointer types before dereferencing
if v.Kind() == reflect.Ptr && !v.IsNil() {
if timeVal, ok := handleTimePointer(v); ok {
return timeVal, nil
}
}

// Dereference pointer if necessary
for v.Kind() == reflect.Ptr {
if v.IsNil() {
Expand All @@ -743,7 +965,15 @@ func (m *Marshaler) marshalValue(v reflect.Value) (format.Value, error) {

switch v.Kind() {
case reflect.Struct:
return m.marshalStruct(v)
// Handle special struct types before generic struct marshaling
switch v.Type() {
case reflect.TypeOf(time.Time{}):
// Marshal time.Time as RFC3339 string
timeVal := v.Interface().(time.Time)
return NewString(timeVal.Format(time.RFC3339)), nil
default:
return m.marshalStruct(v)
}
case reflect.Map:
if v.Type().Key().Kind() != reflect.String {
return nil, fmt.Errorf("map key must be string type")
Expand All @@ -754,6 +984,11 @@ func (m *Marshaler) marshalValue(v reflect.Value) (format.Value, error) {
case reflect.Float32, reflect.Float64:
return NewNumberFromFloat(v.Float()), nil
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
// Handle time.Duration before generic int64 handling
if v.Type() == reflect.TypeOf(time.Duration(0)) {
durationVal := v.Interface().(time.Duration)
return NewString(durationVal.String()), nil
}
return NewNumberFromInteger(int(v.Int())), nil
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
return NewNumberFromInteger(int(v.Uint())), nil
Expand Down
Loading
Loading