Skip to content

Allow overriding XMLReader used in parsing #636

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jan 26, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 28 additions & 0 deletions jvm/src/test/scala/scala/xml/XMLTest.scala
Original file line number Diff line number Diff line change
Expand Up @@ -657,6 +657,34 @@ class XMLTestJVM {
// Calls `roundtrip` with namespaceAware = true on a document that declares a default
// namespace plus the prefixed `xi` and `svg` namespaces.
// NOTE(review): `roundtrip` is defined outside this view; presumably it parses the string and
// checks that serializing the result reproduces it - confirm against its definition.
def namespaceAware2: Unit =
roundtrip(namespaceAware = true, """<book xmlns="http://docbook.org/ns/docbook" xmlns:xi="http://www.w3.org/2001/XInclude"><svg xmlns:svg="http://www.w3.org/2000/svg"/></book>""")

/**
 * Checks that a loader obtained from `XML.withXMLReader` parses through the supplied reader:
 * an XMLFilter that rewrites character data must be reflected in the loaded element.
 */
@UnitTest
def useXMLReaderWithXMLFilter(): Unit = {
  val parent: org.xml.sax.XMLReader = javax.xml.parsers.SAXParserFactory.newInstance.newSAXParser.getXMLReader
  // The filter forwards every event to its parent handler, replacing each 'a' in character
  // data with 'b'. Element names are not character data, so `<a>` itself is left alone.
  val filter: org.xml.sax.XMLFilter = new org.xml.sax.helpers.XMLFilterImpl(parent) {
    override def characters(ch: Array[Char], start: Int, length: Int): Unit = {
      // Transform a copy of the reported range instead of writing into the buffer handed
      // over by the parser.
      val replaced: Array[Char] = ch.slice(start, start + length).map(c => if (c == 'a') 'b' else c)
      super.characters(replaced, 0, replaced.length)
    }
  }
  // JUnit's assertEquals takes (expected, actual); keeping that order makes a failure
  // message report the two values under the right labels.
  assertEquals("<a>cbffeebbby</a>", XML.withXMLReader(filter).loadString("<a>caffeeaaay</a>").toString)
}

/**
 * Checks that loading XML does not replace an ErrorHandler that was already installed on
 * `XML.reader`: parsing a malformed document must notify the pre-installed handler.
 */
@UnitTest
def checkThatErrorHandlerIsNotOverwritten(): Unit = {
  val reader: org.xml.sax.XMLReader = XML.reader
  // Remember whatever handler was installed before, so this test does not leak its own
  // handler into other tests that use the same reader.
  val previousHandler: org.xml.sax.ErrorHandler = reader.getErrorHandler
  var gotAnError: Boolean = false
  reader.setErrorHandler(new org.xml.sax.ErrorHandler {
    override def warning(e: SAXParseException): Unit = gotAnError = true
    override def error(e: SAXParseException): Unit = gotAnError = true
    override def fatalError(e: SAXParseException): Unit = gotAnError = true
  })
  try {
    // "<a>" is never closed, so the parser has to report a problem.
    XML.loadString("<a>")
  } catch {
    // The parser may still throw after notifying the handler; only the notification matters here.
    case _: SAXParseException =>
  } finally {
    reader.setErrorHandler(previousHandler)
  }
  assertTrue(gotAnError)
}

@UnitTest
def nodeSeqNs: Unit = {
val x = {
Expand Down
8 changes: 6 additions & 2 deletions shared/src/main/scala/scala/xml/XML.scala
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@ package scala
package xml

import factory.XMLLoader
import java.io.{ File, FileDescriptor, FileInputStream, FileOutputStream }
import java.io.{ InputStream, Reader, StringReader }
import java.io.{File, FileDescriptor, FileInputStream, FileOutputStream}
import java.io.{InputStream, Reader, StringReader}
import java.nio.channels.Channels
import scala.util.control.Exception.ultimately

Expand Down Expand Up @@ -72,6 +72,10 @@ object XML extends XMLLoader[Elem] {
def withSAXParser(p: SAXParser): XMLLoader[Elem] =
new XMLLoader[Elem] { override val parser: SAXParser = p }

/**
 * Returns an XMLLoader whose load* methods will use the supplied XMLReader.
 *
 * @param r the XMLReader that the returned loader exposes as its `reader`
 * @return a new XMLLoader[Elem] whose `reader` is overridden with `r`
 */
def withXMLReader(r: XMLReader): XMLLoader[Elem] =
new XMLLoader[Elem] { override val reader: XMLReader = r }

/**
* Saves a node to a file with given filename using given encoding
* optionally with xmldecl and doctype declaration.
Expand Down
74 changes: 46 additions & 28 deletions shared/src/main/scala/scala/xml/factory/XMLLoader.scala
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ package scala
package xml
package factory

import org.xml.sax.SAXNotRecognizedException
import org.xml.sax.{SAXNotRecognizedException, XMLReader}
import javax.xml.parsers.SAXParserFactory
import parsing.{FactoryAdapter, NoBindingFactoryAdapter}
import java.io.{File, FileDescriptor, InputStream, Reader}
Expand Down Expand Up @@ -46,59 +46,77 @@ trait XMLLoader[T <: Node] {
/* Override this to use a different SAXParser. By default this returns `parserInstance.get`. */
def parser: SAXParser = parserInstance.get

/* Override this to use a different XMLReader. By default this is the XMLReader obtained from `parser`. */
def reader: XMLReader = parser.getXMLReader

/**
* Loads XML from the given InputSource, using the supplied parser.
* The methods available in scala.xml.XML use the XML parser in the JDK.
*/
def loadXML(source: InputSource, parser: SAXParser): T = {
val result: FactoryAdapter = parse(source, parser)
def loadXML(source: InputSource, parser: SAXParser): T = loadXML(source, parser.getXMLReader)

def loadXMLNodes(source: InputSource, parser: SAXParser): Seq[Node] = loadXMLNodes(source, parser.getXMLReader)

private def loadXML(source: InputSource, reader: XMLReader): T = {
val result: FactoryAdapter = parse(source, reader)
result.rootElem.asInstanceOf[T]
}

def loadXMLNodes(source: InputSource, parser: SAXParser): Seq[Node] = {
val result: FactoryAdapter = parse(source, parser)
private def loadXMLNodes(source: InputSource, reader: XMLReader): Seq[Node] = {
val result: FactoryAdapter = parse(source, reader)
result.prolog ++ (result.rootElem :: result.epilogue)
}

private def parse(source: InputSource, parser: SAXParser): FactoryAdapter = {
private def parse(source: InputSource, reader: XMLReader): FactoryAdapter = {
if (source == null) throw new IllegalArgumentException("InputSource cannot be null")

val result: FactoryAdapter = adapter

reader.setContentHandler(result)
reader.setDTDHandler(result)
/* Do not overwrite pre-configured EntityResolver. */
if (reader.getEntityResolver == null) reader.setEntityResolver(result)
/* Do not overwrite pre-configured ErrorHandler. */
if (reader.getErrorHandler == null) reader.setErrorHandler(result)

try {
parser.setProperty("http://xml.org/sax/properties/lexical-handler", result)
reader.setProperty("http://xml.org/sax/properties/lexical-handler", result)
} catch {
case _: SAXNotRecognizedException =>
}

result.scopeStack = TopScope :: result.scopeStack
parser.parse(source, result)
reader.parse(source)
result.scopeStack = result.scopeStack.tail

result
}

/** Loads XML from the given InputSource. */
def load(source: InputSource): T = loadXML(source, reader)

/** Loads XML from the given file, file descriptor, or filename. */
def loadFile(file: File): T = loadXML(fromFile(file), parser)
def loadFile(fd: FileDescriptor): T = loadXML(fromFile(fd), parser)
def loadFile(name: String): T = loadXML(fromFile(name), parser)
def loadFile(file: File): T = load(fromFile(file))
def loadFile(fd: FileDescriptor): T = load(fromFile(fd))
def loadFile(name: String): T = load(fromFile(name))

/** loads XML from given InputStream, Reader, sysID, InputSource, or URL. */
def load(is: InputStream): T = loadXML(fromInputStream(is), parser)
def load(reader: Reader): T = loadXML(fromReader(reader), parser)
def load(sysID: String): T = loadXML(fromSysId(sysID), parser)
def load(source: InputSource): T = loadXML(source, parser)
def load(url: URL): T = loadXML(fromInputStream(url.openStream()), parser)
/** Loads XML from the given InputStream, Reader, sysID, or URL. */
def load(is: InputStream): T = load(fromInputStream(is))
def load(reader: Reader): T = load(fromReader(reader))
def load(sysID: String): T = load(fromSysId(sysID))
def load(url: URL): T = load(fromInputStream(url.openStream()))

/** Loads XML from the given String. */
def loadString(string: String): T = loadXML(fromString(string), parser)
def loadString(string: String): T = load(fromString(string))

/** Load XML nodes, including comments and processing instructions that precede and follow the root element. */
def loadFileNodes(file: File): Seq[Node] = loadXMLNodes(fromFile(file), parser)
def loadFileNodes(fd: FileDescriptor): Seq[Node] = loadXMLNodes(fromFile(fd), parser)
def loadFileNodes(name: String): Seq[Node] = loadXMLNodes(fromFile(name), parser)
def loadNodes(is: InputStream): Seq[Node] = loadXMLNodes(fromInputStream(is), parser)
def loadNodes(reader: Reader): Seq[Node] = loadXMLNodes(fromReader(reader), parser)
def loadNodes(sysID: String): Seq[Node] = loadXMLNodes(fromSysId(sysID), parser)
def loadNodes(source: InputSource): Seq[Node] = loadXMLNodes(source, parser)
def loadNodes(url: URL): Seq[Node] = loadXMLNodes(fromInputStream(url.openStream()), parser)
def loadStringNodes(string: String): Seq[Node] = loadXMLNodes(fromString(string), parser)
def loadNodes(source: InputSource): Seq[Node] = loadXMLNodes(source, reader)
def loadFileNodes(file: File): Seq[Node] = loadNodes(fromFile(file))
def loadFileNodes(fd: FileDescriptor): Seq[Node] = loadNodes(fromFile(fd))
def loadFileNodes(name: String): Seq[Node] = loadNodes(fromFile(name))
def loadNodes(is: InputStream): Seq[Node] = loadNodes(fromInputStream(is))
def loadNodes(reader: Reader): Seq[Node] = loadNodes(fromReader(reader))
def loadNodes(sysID: String): Seq[Node] = loadNodes(fromSysId(sysID))
def loadNodes(url: URL): Seq[Node] = loadNodes(fromInputStream(url.openStream()))
def loadStringNodes(string: String): Seq[Node] = loadNodes(fromString(string))
}
1 change: 1 addition & 0 deletions shared/src/main/scala/scala/xml/package.scala
Original file line number Diff line number Diff line change
Expand Up @@ -80,5 +80,6 @@ package object xml {
type SAXParseException = org.xml.sax.SAXParseException
type EntityResolver = org.xml.sax.EntityResolver
type InputSource = org.xml.sax.InputSource
type XMLReader = org.xml.sax.XMLReader
type SAXParser = javax.xml.parsers.SAXParser
}
5 changes: 3 additions & 2 deletions shared/src/main/scala/scala/xml/parsing/MarkupParser.scala
Original file line number Diff line number Diff line change
Expand Up @@ -98,8 +98,9 @@ trait MarkupParser extends MarkupParserCommon with TokenTests {
var extIndex = -1

/** holds temporary values of pos */
// Note: this is clearly an override, but if marked as such it causes a "...cannot override a mutable variable"
// error with Scala 3; does it work with Scala 3 if not explicitly marked as an override remains to be seen...
// Note: if marked as an override, this causes a "...cannot override a mutable variable" error with Scala 3;
// SethTisue noted on Oct 14, 2021 that lampepfl/dotty#13744 should fix it - and it probably did,
// but Scala XML still builds against a Scala 3 version that has this bug, so this still cannot be marked as an override :(
var tmppos: Int = _

/** holds the next character */
Expand Down