diff --git a/type/interfaces.go b/type/interfaces.go index fb86197..7f33604 100644 --- a/type/interfaces.go +++ b/type/interfaces.go @@ -11,6 +11,10 @@ type ValueContainer interface { HasValue() bool } +type Recoverable[T any] interface { + CatchUnwrap(T) +} + type Unwrappable[T any] interface { Expect(string) T // panics with a provided custom message Unwrap() T // panics with a generic message @@ -30,7 +34,8 @@ type Optionaler[T any] interface { IsNone() bool OkOr(error) Result[T] OkOrElse(func() error) Result[T] - Unwrappable[T] + Optioner[T] + OptionalerMarker } type Resulter[T any] interface { @@ -38,11 +43,24 @@ type Resulter[T any] interface { IsErr() bool Ok() Optional[T] Err() Optional[error] - Unwrappable[T] + Optioner[T] + ResulterMarker } +// Marker interfaces to help type matching +type ( + ResulterMarker interface { + Result() + } + OptionalerMarker interface { + Optional() + } +) + // Ensure compile time the interfaces are implemented var ( + _ OptionalerMarker = (*Optional[any])(nil) + _ ResulterMarker = (*Result[any])(nil) _ Optioner[any] = (*Optional[any])(nil) _ Optioner[any] = (*Result[any])(nil) _ Optionaler[any] = (*Optional[any])(nil) diff --git a/type/optional.go b/type/optional.go index 5fa0698..a84895f 100644 --- a/type/optional.go +++ b/type/optional.go @@ -12,6 +12,9 @@ import ( Assert "github.com/lbatuska/goutils/assert" ) +// Marker interface impl +func (opt Optional[T]) Optional() {} + // CTORS BEGIN func Some[T any](value T) Optional[T] { return Optional[T]{value, true} @@ -57,11 +60,14 @@ func (opt Optional[T]) Expect(msg string) T { panic(msg) } -func (opt Optional[T]) Unwrap() T { +func (opt *Optional[T]) Unwrap() T { + if opt == nil { + panic("Tried unwrapping an Optional that did not have a value!") + } if opt.present { return opt.value } - panic("Tried unwrapping an Optional that did not have a value!") + panic(opt) } func (opt *Optional[T]) UnwrapOr(val T) T { diff --git a/type/result.go b/type/result.go index 5cca7be..3167e00 100644 --- a/type/result.go +++ b/type/result.go @@ -12,6 +12,9 @@ import ( Assert "github.com/lbatuska/goutils/assert" ) +// Marker interface impl +func (res Result[T]) Result() {} + // CTORS BEGIN func Ok[T any](value T) Result[T] { return Result[T]{value: value, err: nil} @@ -56,11 +59,14 @@ func (res Result[T]) Expect(msg string) T { panic(msg) } -func (res Result[T]) Unwrap() T { +func (res *Result[T]) Unwrap() T { + if res == nil { + panic("Tried unwrapping a Result that had an error value!") + } if res.err == nil { return res.value } - panic("Tried unwrapping a Result that had an error value!") + panic(res) } func (res *Result[T]) UnwrapOr(val T) T { diff --git a/type/utils.go b/type/utils.go index 9f0ac61..cec49e0 100644 --- a/type/utils.go +++ b/type/utils.go @@ -1,5 +1,11 @@ package Type +import ( + "fmt" + "reflect" + "unsafe" +) + func Expect[T any](val Unwrappable[T], msg string) T { return val.Expect(msg) } @@ -42,3 +48,92 @@ func ResultWrapb[T any](err error, val T) Result[T] { func Ptr[T any](v T) *T { return &v } + +// meant to be used as defer Type.CatchUnwrap(Type.Ptr(&res)) or Type.CatchUnwrap(&res) if res is already a pointer +// where res is a pointer to an Option or Result returned by a function (initialized to not be nil) +// func X() (res *Optional[int]) { +// res = None[int]() +// defer CatchUnwrap(&res) +// +// // Some possibly unsafe unwrapping of values +// return res +// } +// === OR === +// func X() (res Optional[int]) { +// res = None[int]() +// defer Type.CatchUnwrap(Type.Ptr(&res)) +// +// // Some possibly unsafe unwrapping of values +// return res +// } + +func CatchUnwrap(ret interface{}) { + r := recover() + if r == nil { + return + } + vp := reflect.ValueOf(r) + if vp.Kind() != reflect.Pointer || vp.IsNil() { + panic(r) + } + + if _, ok := r.(OptionalerMarker); ok { + if setOptionalNone(ret) { + return + } + } + + if _, ok := r.(ResulterMarker); ok { + if setResultError(ret, fmt.Errorf("Tried to unwrap a failed result!")) { + return + } + } + + panic(r) +} + +func setOptionalNone(ret interface{}) bool { + v := reflect.ValueOf(ret) + if v.Kind() != reflect.Ptr || v.IsNil() { + return false + } + elem := v.Elem().Elem() + if elem.Kind() != reflect.Struct { + return false + } + presentField := elem.FieldByName("present") + if !presentField.IsValid() { + return false + } + if presentField.Kind() == reflect.Bool { + presentField = reflect.NewAt(presentField.Type(), + unsafe.Pointer(presentField.UnsafeAddr())).Elem() + presentField.SetBool(false) // Setting it to false (None) + } else { + return false + } + return true +} + +func setResultError(ret interface{}, err error) bool { + v := reflect.ValueOf(ret) + if v.Kind() != reflect.Ptr || v.IsNil() { + return false + } + elem := v.Elem().Elem() + if elem.Kind() != reflect.Struct { + return false + } + errField := elem.FieldByName("err") + if !errField.IsValid() { + return false + } + if errField.Kind() == reflect.Interface { + errField = reflect.NewAt(errField.Type(), + unsafe.Pointer(errField.UnsafeAddr())).Elem() + errField.Set(reflect.ValueOf(err)) // Setting the error + } else { + return false + } + return true +}