Skip to content

XMLLoader returns Document #663

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jun 7, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 2 additions & 6 deletions jvm/src/test/scala/scala/xml/XMLTest.scala
Original file line number Diff line number Diff line change
@@ -1,14 +1,10 @@
package scala.xml

import language.postfixOps

import org.junit.{Test => UnitTest}
import org.junit.Assert.assertTrue
import org.junit.Assert.assertFalse
import org.junit.Assert.assertEquals
import org.junit.Assert.{assertEquals, assertFalse, assertTrue}
import java.io.StringWriter
import java.io.ByteArrayOutputStream
import java.io.StringReader
import scala.xml.dtd.{DocType, PublicID}
import scala.xml.parsing.ConstructingParser
import scala.xml.Utility.sort
Expand Down Expand Up @@ -610,7 +606,7 @@ class XMLTestJVM {
| section]]> </b> suffix</a>""".stripMargin)
}

def roundtripNodes(xml: String): Unit = assertEquals(xml, XML.loadStringNodes(xml).map(_.toString).mkString(""))
// Test helper: parses `xml` into a Document, renders each of the document's children with
// toString, concatenates the results with no separator, and asserts the output equals `xml`.
def roundtripNodes(xml: String): Unit = assertEquals(xml, XML.loadStringDocument(xml).children.map(_.toString).mkString(""))

@UnitTest
def xmlLoaderLoadNodes(): Unit = {
Expand Down
39 changes: 24 additions & 15 deletions shared/src/main/scala/scala/xml/XML.scala
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ package xml
import factory.XMLLoader
import java.io.{File, FileDescriptor, FileInputStream, FileOutputStream, InputStream, Reader, StringReader, Writer}
import java.nio.channels.Channels
import scala.util.control.Exception.ultimately
import scala.util.control.Exception

object Source {
def fromFile(name: String): InputSource = fromFile(new File(name))
Expand All @@ -25,8 +25,8 @@ object Source {
def fromSysId(sysID: String): InputSource = new InputSource(sysID)
def fromFile(fd: FileDescriptor): InputSource = fromInputStream(new FileInputStream(fd))
def fromInputStream(is: InputStream): InputSource = new InputSource(is)
def fromReader(reader: Reader): InputSource = new InputSource(reader)
def fromString(string: String): InputSource = fromReader(new StringReader(string))
def fromReader(reader: Reader): InputSource = new InputSource(reader)
}

/**
Expand Down Expand Up @@ -68,12 +68,14 @@ object XML extends XMLLoader[Elem] {
val encoding: String = "UTF-8"

/** Returns an XMLLoader whose load* methods will use the supplied SAXParser. */
def withSAXParser(p: SAXParser): XMLLoader[Elem] =
new XMLLoader[Elem] { override val parser: SAXParser = p }
def withSAXParser(p: SAXParser): XMLLoader[Elem] = new XMLLoader[Elem] {
// Only `parser` is overridden. XMLLoader.reader is defined as parser.getXMLReader, so the
// load* methods that go through `reader` end up using the XMLReader obtained from `p`.
override val parser: SAXParser = p
}

/** Returns an XMLLoader whose load* methods will use the supplied XMLReader. */
def withXMLReader(r: XMLReader): XMLLoader[Elem] =
new XMLLoader[Elem] { override val reader: XMLReader = r }
def withXMLReader(r: XMLReader): XMLLoader[Elem] = new XMLLoader[Elem] {
// Overrides `reader` directly (a val replacing XMLLoader's def), so the load* methods that
// go through `reader` use `r` instead of asking the loader's parser for one.
override val reader: XMLReader = r
}

/**
* Saves a node to a file with given filename using given encoding
Expand All @@ -94,15 +96,15 @@ object XML extends XMLLoader[Elem] {
node: Node,
enc: String = "UTF-8",
xmlDecl: Boolean = false,
doctype: dtd.DocType = null): Unit =
{
val fos: FileOutputStream = new FileOutputStream(filename)
val w: Writer = Channels.newWriter(fos.getChannel, enc)
doctype: dtd.DocType = null
): Unit = {
val fos: FileOutputStream = new FileOutputStream(filename)
val w: Writer = Channels.newWriter(fos.getChannel, enc)

ultimately(w.close())(
write(w, node, enc, xmlDecl, doctype)
)
}
Exception.ultimately(w.close())(
write(w, node, enc, xmlDecl, doctype)
)
}

/**
* Writes the given node using writer, optionally with xml decl and doctype.
Expand All @@ -114,7 +116,14 @@ object XML extends XMLLoader[Elem] {
* @param xmlDecl if true, write xml declaration
* @param doctype if not null, write doctype declaration
*/
final def write(w: java.io.Writer, node: Node, enc: String, xmlDecl: Boolean, doctype: dtd.DocType, minimizeTags: MinimizeMode.Value = MinimizeMode.Default): Unit = {
final def write(
w: Writer,
node: Node,
enc: String,
xmlDecl: Boolean,
doctype: dtd.DocType,
minimizeTags: MinimizeMode.Value = MinimizeMode.Default
): Unit = {
/* TODO: optimize by giving writer parameter to toXML*/
if (xmlDecl) w.write("<?xml version='1.0' encoding='" + enc + "'?>\n")
if (doctype ne null) w.write(doctype.toString + "\n")
Expand Down
107 changes: 42 additions & 65 deletions shared/src/main/scala/scala/xml/factory/XMLLoader.scala
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,9 @@ package scala
package xml
package factory

import org.xml.sax.{SAXNotRecognizedException, SAXNotSupportedException, XMLReader}
import org.xml.sax.XMLReader
import scala.xml.Source
import javax.xml.parsers.SAXParserFactory
import parsing.{FactoryAdapter, NoBindingFactoryAdapter}
import java.io.{File, FileDescriptor, InputStream, Reader}
import java.net.URL

Expand All @@ -25,9 +25,6 @@ import java.net.URL
* created by "def parser" or the reader created by "def reader".
*/
trait XMLLoader[T <: Node] {
import scala.xml.Source._
def adapter: FactoryAdapter = new NoBindingFactoryAdapter()

private def setSafeDefaults(parserFactory: SAXParserFactory): Unit = {
parserFactory.setFeature("http://javax.xml.XMLConstants/feature/secure-processing", true)
parserFactory.setFeature("http://apache.org/xml/features/nonvalidating/load-external-dtd", false)
Expand All @@ -54,69 +51,49 @@ trait XMLLoader[T <: Node] {
def reader: XMLReader = parser.getXMLReader

/**
* Loads XML from the given InputSource, using the supplied parser.
* Loads XML from the given InputSource, using the supplied parser or reader.
* The methods available in scala.xml.XML use the XML parser in the JDK
* (unless another parser is present on the classpath).
*/
def loadXML(inputSource: InputSource, parser: SAXParser): T = loadXML(inputSource, parser.getXMLReader)

def loadXMLNodes(inputSource: InputSource, parser: SAXParser): Seq[Node] = loadXMLNodes(inputSource, parser.getXMLReader)

private def loadXML(inputSource: InputSource, reader: XMLReader): T = {
val result: FactoryAdapter = parse(inputSource, reader)
result.rootElem.asInstanceOf[T]
}

private def loadXMLNodes(inputSource: InputSource, reader: XMLReader): Seq[Node] = {
val result: FactoryAdapter = parse(inputSource, reader)
result.prolog ++ (result.rootElem :: result.epilogue)
}

private def parse(inputSource: InputSource, xmlReader: XMLReader): FactoryAdapter = {
if (inputSource == null) throw new IllegalArgumentException("InputSource cannot be null")

val result: FactoryAdapter = adapter

xmlReader.setContentHandler(result)
xmlReader.setDTDHandler(result)
/* Do not overwrite pre-configured EntityResolver. */
if (xmlReader.getEntityResolver == null) xmlReader.setEntityResolver(result)
/* Do not overwrite pre-configured ErrorHandler. */
if (xmlReader.getErrorHandler == null) xmlReader.setErrorHandler(result)

try {
xmlReader.setProperty("http://xml.org/sax/properties/lexical-handler", result)
} catch {
case _: SAXNotRecognizedException =>
case _: SAXNotSupportedException =>
}

result.scopeStack = TopScope :: result.scopeStack
xmlReader.parse(inputSource)
result.scopeStack = result.scopeStack.tail

result
}

/** Loads XML. */
def load(inputSource: InputSource): T = loadXML(inputSource, reader)
def loadFile(fileName: String): T = load(fromFile(fileName))
def loadFile(file: File): T = load(fromFile(file))
def load(url: URL): T = load(fromUrl(url))
def load(sysId: String): T = load(fromSysId(sysId))
def loadFile(fileDescriptor: FileDescriptor): T = load(fromFile(fileDescriptor))
def load(inputStream: InputStream): T = load(fromInputStream(inputStream))
def load(reader: Reader): T = load(fromReader(reader))
def loadString(string: String): T = load(fromString(string))
/**
 * Returns the document's `docElem`, cast to this loader's element type `T`.
 *
 * NOTE(review): `T` is a type parameter, so this `asInstanceOf` is presumably unchecked at
 * runtime (erasure); a mismatch would then surface at the use site rather than here -- confirm.
 */
private def getDocElem(document: Document): T = document.docElem.asInstanceOf[T]

// Both methods parse the whole document with the caller-supplied SAXParser and then project it:
// loadXML returns the document element (cast to T via getDocElem), loadXMLNodes returns the
// document's `children`.
def loadXML(inputSource: InputSource, parser: SAXParser): T = getDocElem(loadDocument(inputSource, parser))
def loadXMLNodes(inputSource: InputSource, parser: SAXParser): Seq[Node] = loadDocument(inputSource, parser).children

// Private entry points: the actual parsing is delegated to the FactoryAdapter's loadDocument,
// with either an explicit SAXParser or an explicit XMLReader.
private def loadDocument(inputSource: InputSource, parser: SAXParser): Document = adapter.loadDocument(inputSource, parser)
private def loadDocument(inputSource: InputSource, reader: XMLReader): Document = adapter.loadDocument(inputSource, reader)
// `adapter` is a def, so every access (and therefore every load call above) constructs a new
// NoBindingFactoryAdapter unless a subclass overrides it.
def adapter: parsing.FactoryAdapter = new parsing.NoBindingFactoryAdapter()

/**
 * Loads XML Document.
 *
 * The InputSource overload parses with this loader's `reader`; every other overload below
 * wraps its argument in an InputSource through the matching `Source.from*` factory and
 * delegates to the InputSource overload.
 */
def loadDocument(source: InputSource): Document = loadDocument(source, reader)
def loadFileDocument(fileName: String): Document = loadDocument(Source.fromFile(fileName))
def loadFileDocument(file: File): Document = loadDocument(Source.fromFile(file))
def loadDocument(url: URL): Document = loadDocument(Source.fromUrl(url))
def loadDocument(sysId: String): Document = loadDocument(Source.fromSysId(sysId))
def loadFileDocument(fileDescriptor: FileDescriptor): Document = loadDocument(Source.fromFile(fileDescriptor))
def loadDocument(inputStream: InputStream): Document = loadDocument(Source.fromInputStream(inputStream))
def loadDocument(reader: Reader): Document = loadDocument(Source.fromReader(reader))
def loadStringDocument(string: String): Document = loadDocument(Source.fromString(string))

/**
 * Loads XML element.
 *
 * Each overload loads the full Document through the corresponding
 * loadDocument / loadFileDocument / loadStringDocument method and returns only its
 * document element, cast to `T` by getDocElem.
 */
def load(inputSource: InputSource): T = getDocElem(loadDocument(inputSource))
def loadFile(fileName: String): T = getDocElem(loadFileDocument(fileName))
def loadFile(file: File): T = getDocElem(loadFileDocument(file))
def load(url: URL): T = getDocElem(loadDocument(url))
def load(sysId: String): T = getDocElem(loadDocument(sysId))
def loadFile(fileDescriptor: FileDescriptor): T = getDocElem(loadFileDocument(fileDescriptor))
def load(inputStream: InputStream): T = getDocElem(loadDocument(inputStream))
def load(reader: Reader): T = getDocElem(loadDocument(reader))
def loadString(string: String): T = getDocElem(loadStringDocument(string))

/** Loads XML nodes, including comments and processing instructions that precede and follow the root element. */
def loadNodes(inputSource: InputSource): Seq[Node] = loadXMLNodes(inputSource, reader)
def loadFileNodes(fileName: String): Seq[Node] = loadNodes(fromFile(fileName))
def loadFileNodes(file: File): Seq[Node] = loadNodes(fromFile(file))
def loadNodes(url: URL): Seq[Node] = loadNodes(fromUrl(url))
def loadNodes(sysId: String): Seq[Node] = loadNodes(fromSysId(sysId))
def loadFileNodes(fileDescriptor: FileDescriptor): Seq[Node] = loadNodes(fromFile(fileDescriptor))
def loadNodes(inputStream: InputStream): Seq[Node] = loadNodes(fromInputStream(inputStream))
def loadNodes(reader: Reader): Seq[Node] = loadNodes(fromReader(reader))
def loadStringNodes(string: String): Seq[Node] = loadNodes(fromString(string))
// Each overload loads the full Document through the corresponding
// loadDocument / loadFileDocument / loadStringDocument method and returns its `children`.
def loadNodes(inputSource: InputSource): Seq[Node] = loadDocument(inputSource).children
def loadFileNodes(fileName: String): Seq[Node] = loadFileDocument(fileName).children
def loadFileNodes(file: File): Seq[Node] = loadFileDocument(file).children
def loadNodes(url: URL): Seq[Node] = loadDocument(url).children
def loadNodes(sysId: String): Seq[Node] = loadDocument(sysId).children
def loadFileNodes(fileDescriptor: FileDescriptor): Seq[Node] = loadFileDocument(fileDescriptor).children
def loadNodes(inputStream: InputStream): Seq[Node] = loadDocument(inputStream).children
def loadNodes(reader: Reader): Seq[Node] = loadDocument(reader).children
def loadStringNodes(string: String): Seq[Node] = loadStringDocument(string).children
}
Loading