@@ -42,7 +42,7 @@ func AddAfterPackage(f *proto.Proto, v proto.Visitee) error {
4242 if inserted {
4343 return nil
4444 }
45- return errors .New ("could not find package statement" )
45+ return errors .New ("could not find proto package statement" )
4646}
4747
4848// Fallback logic, try and use import after a package and if that fails
@@ -118,7 +118,7 @@ func AddImports(f *proto.Proto, fallback bool, imports ...*proto.Import) (err er
118118 // recurse with the rest. (might be empty)
119119 return AddImports (f , false , imports [1 :]... )
120120 }
121- return errors .New ("unable to add import, no import statements found" )
121+ return errors .New ("unable to add proto import, no import statements found" )
122122}
123123
124124// NextUniqueID goes through the fields of the given Message and returns
@@ -180,10 +180,11 @@ func GetMessageByName(f *proto.Proto, name string) (node *proto.Message, err err
180180 },
181181 // return immediately iff found.
182182 func (* Cursor ) bool { return ! found })
183+
183184 if found {
184185 return
185186 }
186- return nil , errors .Errorf ("message %s not found" , name )
187+ return nil , errors .Errorf ("proto message %s not found" , name )
187188}
188189
189190// GetServiceByName returns the service with the given name or nil if not found.
@@ -192,8 +193,12 @@ func GetMessageByName(f *proto.Proto, name string) (node *proto.Message, err err
192193// f, _ := ParseProtoPath("foo.proto")
193194// s := GetServiceByName(f, "FooSrv")
194195// s.Name // "FooSrv"
195- func GetServiceByName (f * proto.Proto , name string ) (node * proto.Service , err error ) {
196- node , err = nil , nil
196+ func GetServiceByName (f * proto.Proto , name string ) (* proto.Service , error ) {
197+ var (
198+ node * proto.Service
199+ err error
200+ )
201+
197202 found := false
198203 Apply (f ,
199204 func (c * Cursor ) bool {
@@ -209,12 +214,15 @@ func GetServiceByName(f *proto.Proto, name string) (node *proto.Service, err err
209214 _ , ok := c .Node ().(* proto.Proto )
210215 return ok
211216 },
217+
212218 // return immediately iff found.
213- func (* Cursor ) bool { return ! found })
219+ func (* Cursor ) bool { return ! found },
220+ )
214221 if found {
215- return
222+ return node , err
216223 }
217- return nil , errors .Errorf ("service %s not found" , name )
224+
225+ return nil , errors .Errorf ("proto service %s not found" , name )
218226}
219227
220228// GetImportByPath returns the import with the given path or nil if not found.
@@ -223,9 +231,13 @@ func GetServiceByName(f *proto.Proto, name string) (node *proto.Service, err err
223231// f, _ := ParseProtoPath("foo.proto")
224232// s := GetImportByPath(f, "other.proto")
225233// s.FileName // "other.proto"
226- func GetImportByPath (f * proto.Proto , path string ) (node * proto.Import , err error ) {
234+ func GetImportByPath (f * proto.Proto , path string ) (* proto.Import , error ) {
235+ var (
236+ node * proto.Import
237+ err error
238+ )
239+
227240 found := false
228- node , err = nil , nil
229241 Apply (f ,
230242 func (c * Cursor ) bool {
231243 if i , ok := c .Node ().(* proto.Import ); ok {
@@ -240,12 +252,54 @@ func GetImportByPath(f *proto.Proto, path string) (node *proto.Import, err error
240252 _ , ok := c .Node ().(* proto.Proto )
241253 return ok
242254 },
255+
243256 // return immediately iff found.
244- func (* Cursor ) bool { return ! found })
257+ func (* Cursor ) bool { return ! found },
258+ )
245259 if found {
246- return
260+ return node , err
247261 }
248- return nil , errors .Errorf ("import %s not found" , path )
262+
263+ return nil , errors .Errorf ("proto import %s not found" , path )
264+ }
265+
266+ // GetFieldByName returns the field with the given name or nil if not found within a message.
267+ // Only traverses in proto.Message since they are the only nodes that contain fields:
268+ //
269+ // f, _ := ParseProtoPath("foo.proto")
270+ // m := GetMessageByName(f, "Foo")
271+ // f := GetFieldByName(m, "Bar")
272+ // f.Name // "Bar"
273+ func GetFieldByName (f * proto.Message , name string ) (* proto.NormalField , error ) {
274+ var (
275+ node * proto.NormalField
276+ err error
277+ )
278+
279+ found := false
280+ Apply (f ,
281+ func (c * Cursor ) bool {
282+ if m , ok := c .Node ().(* proto.NormalField ); ok {
283+ if m .Name == name {
284+ found = true
285+ node = m
286+ return false
287+ }
288+ // keep looking if we're in a Message
289+ return true
290+ }
291+ // keep looking while we're in a proto.Message.
292+ _ , ok := c .Node ().(* proto.Message )
293+ return ok
294+ },
295+ // return immediately iff found.
296+ func (* Cursor ) bool { return ! found },
297+ )
298+ if found {
299+ return node , err
300+ }
301+
302+ return nil , errors .Errorf ("proto field %s not found" , name )
249303}
250304
251305// HasMessage returns true if the given message is found in the given file.
@@ -277,3 +331,13 @@ func HasImport(f *proto.Proto, path string) bool {
277331 _ , err := GetImportByPath (f , path )
278332 return err == nil
279333}
334+
335+ func HasField (f * proto.Proto , messageName , field string ) bool {
336+ msg , err := GetMessageByName (f , messageName )
337+ if err != nil {
338+ return false
339+ }
340+
341+ _ , err = GetFieldByName (msg , field )
342+ return err == nil
343+ }
0 commit comments