Commit 28426491 authored by Kunshan Wang's avatar Kunshan Wang

WIP: HAIL parser.

parent 9551ff05
......@@ -13,6 +13,7 @@ import scala.collection.mutable.ArrayBuffer
import scala.collection.immutable.Stream
import java.io.StringWriter
import java.nio.CharBuffer
import uvm.utils.AntlrHelpers._
class UIRTextReader(val idFactory: IDFactory) {
import UIRTextReader._
......@@ -40,23 +41,6 @@ class UIRTextReader(val idFactory: IDFactory) {
read(sb.toString(), globalBundle)
}
class AccumulativeAntlrErrorListener(source: String) extends BaseErrorListener {
val buf = new ArrayBuffer[String]()
var hasError = false
lazy val sourceLines = ArrayBuffer(source.lines.toSeq: _*)
override def syntaxError(recognizer: Recognizer[_, _], offendingSymbol: Object,
line: Int, charPositionInLine: Int, msg: String, e: RecognitionException): Unit = {
val theLine = sourceLines(line - 1)
val marker = " " * charPositionInLine + "^"
buf.add("line %d:%d %s\n%s\n%s".format(line, charPositionInLine, msg, theLine, marker))
hasError = true
}
def getMessages(): String = buf.mkString("\n")
}
def read(source: String, ais: ANTLRInputStream, globalBundle: GlobalBundle): TrantientBundle = {
val ea = new AccumulativeAntlrErrorListener(source)
......@@ -89,12 +73,12 @@ class UIRTextReader(val idFactory: IDFactory) {
val neg = sign match {
case "+" => false
case "-" => true
case "" => false
case "" => false
}
val abs = prefix match {
case "0x" => BigInt(nums, 16)
case "0" => if (nums == "") BigInt(0) else BigInt(nums, 8)
case "" => BigInt(nums, 10)
case "0" => if (nums == "") BigInt(0) else BigInt(nums, 8)
case "" => BigInt(nums, 10)
}
return if (neg) -abs else abs
}
......@@ -108,7 +92,7 @@ class UIRTextReader(val idFactory: IDFactory) {
java.lang.Float.NEGATIVE_INFINITY
else java.lang.Float.POSITIVE_INFINITY
}
case _: FloatNanContext => java.lang.Float.NaN
case _: FloatNanContext => java.lang.Float.NaN
case bits: FloatBitsContext => java.lang.Float.intBitsToFloat(bits.intLiteral().intValue())
}
......@@ -119,7 +103,7 @@ class UIRTextReader(val idFactory: IDFactory) {
java.lang.Double.NEGATIVE_INFINITY
else java.lang.Double.POSITIVE_INFINITY
}
case _: DoubleNanContext => java.lang.Double.NaN
case _: DoubleNanContext => java.lang.Double.NaN
case bits: DoubleBitsContext => java.lang.Double.longBitsToDouble(bits.intLiteral().longValue())
}
......@@ -128,15 +112,6 @@ class UIRTextReader(val idFactory: IDFactory) {
// Printing context information (line, column, near some token)
def inCtx(ctx: ParserRuleContext, s: String): String = nearTok(ctx.getStart, s)
def inCtx(ctx: TerminalNode, s: String): String = nearTok(ctx.getSymbol, s)
def nearTok(tok: Token, s: String): String = {
val line = tok.getLine()
val column = tok.getCharPositionInLine()
val near = tok.getText()
return "At %d:%d near '%s': %s".format(line, column, near, s)
}
def catchIn[T](ctx: ParserRuleContext, s: String)(func: => T): T = try {
func
......@@ -234,24 +209,24 @@ class UIRTextReader(val idFactory: IDFactory) {
def mkType(tc: TypeConstructorContext): Type = {
val ty = tc match {
case t: TypeIntContext => TypeInt(t.length.intValue())
case t: TypeFloatContext => TypeFloat()
case t: TypeDoubleContext => TypeDouble()
case t: TypeRefContext => TypeRef(null).later(phase1) { _.ty = t.ty }
case t: TypeIRefContext => TypeIRef(null).later(phase1) { _.ty = t.ty }
case t: TypeWeakRefContext => TypeWeakRef(null).later(phase1) { _.ty = t.ty }
case t: TypeStructContext => TypeStruct(null).later(phase1) { _.fieldTys = t.fieldTys.map(resTy) }
case t: TypeArrayContext => TypeArray(null, t.length.longValue()).later(phase1) { _.elemTy = t.ty }
case t: TypeHybridContext => TypeHybrid(null, null).later(phase1) { tt => tt.fieldTys = t.fieldTys.map(resTy); tt.varTy = t.varTy }
case t: TypeVoidContext => TypeVoid()
case t: TypeFuncRefContext => TypeFuncRef(null).later(phase1) { _.sig = t.funcSig() }
case t: TypeIntContext => TypeInt(t.length.intValue())
case t: TypeFloatContext => TypeFloat()
case t: TypeDoubleContext => TypeDouble()
case t: TypeRefContext => TypeRef(null).later(phase1) { _.ty = t.ty }
case t: TypeIRefContext => TypeIRef(null).later(phase1) { _.ty = t.ty }
case t: TypeWeakRefContext => TypeWeakRef(null).later(phase1) { _.ty = t.ty }
case t: TypeStructContext => TypeStruct(null).later(phase1) { _.fieldTys = t.fieldTys.map(resTy) }
case t: TypeArrayContext => TypeArray(null, t.length.longValue()).later(phase1) { _.elemTy = t.ty }
case t: TypeHybridContext => TypeHybrid(null, null).later(phase1) { tt => tt.fieldTys = t.fieldTys.map(resTy); tt.varTy = t.varTy }
case t: TypeVoidContext => TypeVoid()
case t: TypeFuncRefContext => TypeFuncRef(null).later(phase1) { _.sig = t.funcSig() }
case t: TypeThreadRefContext => TypeThreadRef()
case t: TypeStackRefContext => TypeStackRef()
case t: TypeTagRef64Context => TypeTagRef64()
case t: TypeVectorContext => TypeVector(null, t.length.longValue()).later(phase1) { _.elemTy = t.ty }
case t: TypeUPtrContext => TypeUPtr(null).later(phase1) { _.ty = t.ty }
case t: TypeUFuncPtrContext => TypeUFuncPtr(null).later(phase1) { _.sig = t.funcSig }
case _ => throw new TextIRParsingException("foo")
case t: TypeStackRefContext => TypeStackRef()
case t: TypeTagRef64Context => TypeTagRef64()
case t: TypeVectorContext => TypeVector(null, t.length.longValue()).later(phase1) { _.elemTy = t.ty }
case t: TypeUPtrContext => TypeUPtr(null).later(phase1) { _.ty = t.ty }
case t: TypeUFuncPtrContext => TypeUFuncPtr(null).later(phase1) { _.sig = t.funcSig }
case _ => throw new TextIRParsingException("foo")
}
return ty
}
......@@ -286,8 +261,8 @@ class UIRTextReader(val idFactory: IDFactory) {
def mkConst(t: Type, c: ConstConstructorContext): Constant = {
val con = c match {
case cc: CtorIntContext => ConstInt(t, cc.intLiteral)
case cc: CtorFloatContext => ConstFloat(t, cc.floatLiteral)
case cc: CtorIntContext => ConstInt(t, cc.intLiteral)
case cc: CtorFloatContext => ConstFloat(t, cc.floatLiteral)
case cc: CtorDoubleContext => ConstDouble(t, cc.doubleLiteral)
case cc: CtorListContext => ConstSeq(t, null).later(phase2) {
_.elems = for (gn <- cc.GLOBAL_NAME()) yield resGlobalVar(gn)
......@@ -471,7 +446,7 @@ class UIRTextReader(val idFactory: IDFactory) {
implicit def resNewStackClause(nsc: NewStackClauseContext): NewStackAction = {
nsc match {
case a: NewStackPassValueContext => PassValues(a.typeList(), a.argList())
case a: NewStackThrowExcContext => ThrowExc(a.exc)
case a: NewStackThrowExcContext => ThrowExc(a.exc)
}
}
......@@ -651,7 +626,7 @@ class UIRTextReader(val idFactory: IDFactory) {
inst.id = idFactory.getID()
inst.name = Option(instDef.name).map(n => globalize(n.getText, bbName))
bb.localInstNs.add(inst)
val instRess: Seq[InstResult] = Option(instDef.instResults) match {
......@@ -700,7 +675,7 @@ object UIRTextReader {
sigil match {
case '@' => name
case '%' => parentName + "." + name.substring(1)
case _ => throw new UvmException("Illegal name '%s'. Name must begin with either '@' or '%%'".format(name))
case _ => throw new UvmException("Illegal name '%s'. Name must begin with either '@' or '%%'".format(name))
}
}
}
\ No newline at end of file
......@@ -23,3 +23,6 @@ class UvmDivisionByZeroException(message: String = null, cause: Throwable = null
/** Thrown when accessing Mu memory but the address is outside the allocated region. */
class UvmIllegalMemoryAccessException(message: String = null, cause: Throwable = null) extends UvmRuntimeException(message, cause)
/** Thrown on syntax errors in HAIL scripts. */
class UvmHailParsingException(message: String = null, cause: Throwable = null) extends UvmRefImplException(message, cause)
package uvm.refimpl.hail
import uvm.refimpl.MicroVM
import scala.collection.JavaConversions.asScalaBuffer
import scala.collection.mutable.HashMap
import org.antlr.v4.runtime.ANTLRInputStream
import org.antlr.v4.runtime.CommonTokenStream
import org.antlr.v4.runtime.ParserRuleContext
import uvm.ir.textinput.TextIRParsingException
import uvm.ir.textinput.gen.HAILLexer
import uvm.ir.textinput.gen.HAILParser
import uvm.ir.textinput.gen.HAILParser.DoubleBitsContext
import uvm.ir.textinput.gen.HAILParser.DoubleInfContext
import uvm.ir.textinput.gen.HAILParser.DoubleLiteralContext
import uvm.ir.textinput.gen.HAILParser.DoubleNanContext
import uvm.ir.textinput.gen.HAILParser.DoubleNumberContext
import uvm.ir.textinput.gen.HAILParser.FixedAllocContext
import uvm.ir.textinput.gen.HAILParser.FloatBitsContext
import uvm.ir.textinput.gen.HAILParser.FloatInfContext
import uvm.ir.textinput.gen.HAILParser.FloatLiteralContext
import uvm.ir.textinput.gen.HAILParser.FloatNanContext
import uvm.ir.textinput.gen.HAILParser.FloatNumberContext
import uvm.ir.textinput.gen.HAILParser.HailContext
import uvm.ir.textinput.gen.HAILParser.HybridAllocContext
import uvm.ir.textinput.gen.HAILParser.IntGlobalContext
import uvm.ir.textinput.gen.HAILParser.IntLitContext
import uvm.ir.textinput.gen.HAILParser.IntLiteralContext
import uvm.ir.textinput.gen.HAILParser.LValueContext
import uvm.ir.textinput.gen.HAILParser.MemInitContext
import uvm.ir.textinput.gen.HAILParser.RValueContext
import uvm.ir.textinput.gen.HAILParser.TypeContext
import uvm.refimpl.MicroVM
import uvm.refimpl.MuCtx
import uvm.refimpl.MuIRefValue
import uvm.refimpl.MuRefValue
import uvm.refimpl.UvmHailParsingException
import uvm.refimpl.UvmHailParsingException
import uvm.refimpl.UvmHailParsingException
import uvm.refimpl.UvmHailParsingException
import uvm.ssavariables.ConstInt
import uvm.types.Type
import uvm.types.TypeHybrid
import uvm.utils.AntlrHelpers.AccumulativeAntlrErrorListener
import uvm.utils.AntlrHelpers.inCtx
import uvm.ir.textinput.gen.HAILParser.IntExprContext
import uvm.refimpl.MuStructValue
import uvm.types.TypeStruct
import uvm.refimpl.UvmHailParsingException
import uvm.types.AbstractSeqType
import uvm.refimpl.mem.HeaderUtils
import uvm.refimpl.itpr.MemoryOperations
import uvm.refimpl.mem.MemorySupport
class HailScriptLoader(implicit microVM: MicroVM) {
class HailScriptLoader(implicit microVM: MicroVM, memorySupport: MemorySupport) {
def loadHail(hailScript: String): Unit = {
val ais = new ANTLRInputStream(hailScript)
val ea = new AccumulativeAntlrErrorListener(hailScript)
val lexer = new HAILLexer(ais)
lexer.removeErrorListeners()
lexer.addErrorListener(ea)
val tokens = new CommonTokenStream(lexer)
val parser = new HAILParser(tokens)
parser.removeErrorListeners()
parser.addErrorListener(ea)
val ast = parser.hail()
if (ea.hasError) {
throw new TextIRParsingException("Syntax error:\n" + ea.getMessages)
}
loadTopLevel(ast)
}
def catchIn[T](ctx: ParserRuleContext, s: String)(func: => T): T = try {
func
} catch {
case e: UvmHailParsingException => throw new UvmHailParsingException(inCtx(ctx, e.getMessage), e)
case e: Exception => throw new UvmHailParsingException(inCtx(ctx, s), e)
}
implicit def resTy(ctx: TypeContext): Type = catchIn(ctx, "Unable to resolve type") { resTyByName(ctx.getText) }
private def resTyByName(name: String): Type = microVM.globalBundle.typeNs(name)
implicit def resConstInt(ctx: IntGlobalContext): BigInt = catchIn(ctx, "Unable to resolve constant int") { resConstIntByName(ctx.getText) }
private def resConstIntByName(name: String): BigInt = {
val const = microVM.globalBundle.constantNs.get(name).getOrElse {
throw new UvmHailParsingException("Type %s not found".format(name))
}
const match {
case ConstInt(ty, num) => num
case _ => throw new UvmHailParsingException("Expected constant int. Found %s: ty=".format(const.repr, const.constTy))
}
}
private type HailObjMap = HashMap[String, MuRefValue]
private def loadTopLevel(ast: HailContext): Unit = {
implicit val mc = microVM.newContext()
try {
implicit val hailObjMap = new HailObjMap
ast.topLevelDef.map(_.getChild(0)) foreach {
case tl: FixedAllocContext => {
val hailName = tl.nam.toString
val ty = resTy(tl.ty)
if (ty.isInstanceOf[TypeHybrid]) {
throw new UvmHailParsingException(inCtx(tl, "Cannot allocate hybrid using '.new'. Found: %s".format(ty)))
}
val obj = mc.newFixed(ty.id)
if (hailObjMap.contains(hailName)) {
throw new UvmHailParsingException(inCtx(tl, "HAIL name %s already used.".format(hailName)))
}
hailObjMap(hailName) = obj
}
case tl: HybridAllocContext => {
val hailName = tl.nam.toString
val ty = resTy(tl.ty)
if (!ty.isInstanceOf[TypeHybrid]) {
throw new UvmHailParsingException(inCtx(tl, "hybrid required. Found %s".format(ty)))
}
val len: Long = evalIntExpr(tl.len).toLong
val obj = mc.newFixed(ty.id)
if (hailObjMap.contains(hailName)) {
throw new UvmHailParsingException(inCtx(tl, "HAIL name %s already used.".format(hailName)))
}
hailObjMap(hailName) = obj
}
case init: MemInitContext => {
val lv = evalLValue(init.lv)
assign(lv, init.rv)
}
}
} finally {
mc.closeContext()
}
}
class LValue private (val iref: MuIRefValue, val varLen: Option[Long], val baseCtx: ParserRuleContext, val curCtx: ParserRuleContext) {
def indexInto(index: Long, ctx: ParserRuleContext)(implicit mc: MuCtx): LValue = {
val (newIRef, newVarLen) = varLen match {
case None => { // not in the var-part of a hybrid
iref.ty.ty match {
case t: TypeStruct => {
val ii = index.toInt
if (ii < 0 || ii >= t.fieldTys.length) {
throw new UvmHailParsingException(inCtx(ctx, "Index out of bound. Struct %s has %d fields. Found index: %d".format(
t, t.fieldTys.length, ii)))
}
val nir = mc.getFieldIRef(iref, index.toInt)
(nir, None)
}
case t: TypeHybrid => {
val ii = index.toInt
if (ii < 0 || ii > t.fieldTys.length) {
throw new UvmHailParsingException(inCtx(ctx, "Index out of bound. Hybrid %s has %d fields. Found index: %d".format(
t, t.fieldTys.length, ii)))
}
if (ii == t.fieldTys.length) {
val nir = mc.getVarPartIRef(iref)
// For debug purpose, we keep the upperbound recorded. Out-of-bound access has undefined behaviour.
val len = HeaderUtils.getVarLength(iref.vb.objRef)
(nir, Some(len))
} else {
val nir = mc.getFieldIRef(iref, index.toInt)
(nir, None)
}
}
case t: AbstractSeqType => {
val ii = index.toLong
if (ii < 0 || ii >= t.len) {
throw new UvmHailParsingException(inCtx(ctx, "Index out of bound. Sequence type %s has %d elements. Found index: %d".format(
t, t.len, ii)))
}
val hII = mc.handleFromInt(ii, 64)
val nir = mc.getElemIRef(iref, hII)
mc.deleteValue(hII)
(nir, None)
}
}
}
case Some(l) => { // in the var-part of a hybrid
val ii = index.toLong
if (ii < 0 || ii >= l) {
throw new UvmHailParsingException(inCtx(ctx, "Index out of bound. Hybrid %s has %d actual var-part elements. Found index: %d".format(
iref.ty, l, ii)))
}
val hII = mc.handleFromInt(ii, 64)
val nir = mc.shiftIRef(iref, hII)
mc.deleteValue(hII)
(nir, None)
}
}
new LValue(newIRef, newVarLen, baseCtx, ctx)
}
}
object LValue {
def forName(name: String, baseCtx: ParserRuleContext)(implicit mc: MuCtx, hailObjMap: HailObjMap): LValue = {
val base = name.charAt(0) match {
case '@' => {
val global = microVM.globalBundle.globalCellNs.get(name).getOrElse {
throw new UvmHailParsingException(inCtx(baseCtx, "Global cell %s not found".format(name)))
}
val gc = mc.handleFromGlobal(global.id)
gc
}
case '$' => {
val ref = hailObjMap.getOrElse(name, {
throw new UvmHailParsingException(inCtx(baseCtx, "HAIL name %s not defined. It needs to be defined BEFORE this".format(name)))
})
val iref = mc.getIRef(ref)
iref
}
}
new LValue(base, None, baseCtx, baseCtx)
}
}
def evalLValue(lv: LValueContext)(implicit mc: MuCtx, hailObjMap: HailObjMap): LValue = {
val base = LValue.forName(lv.nam.getText, lv)
var cur: LValue = base
for (ic <- lv.indices) {
val index = evalIntExpr(ic.intExpr()).toLong
val newCur = cur.indexInto(index, ic)
val oldCur = cur
cur = newCur
mc.deleteValue(oldCur.iref)
}
cur
}
/**
* Index into an iref in a general way.
* @param varLen: None if base is not the var part. Some(l) if base is the 0-th elem of the var part of a hybrid whose actual length is l.
*/
private def generalIndex(base: MuIRefValue, index: Long, varLen: Option[Long])(
implicit mc: MuCtx, ctx: ParserRuleContext): MuIRefValue = {
varLen match {
case None => { // not in the var-part of a hybrid
base.ty.ty match {
case t: TypeStruct => {
val ii = index.toInt
if (ii < 0 || ii >= t.fieldTys.length) {
throw new UvmHailParsingException(inCtx(ctx, "Index out of bound. Struct %s has %d fields. Found index: %d".format(
t, t.fieldTys.length, ii)))
}
mc.getFieldIRef(base, index.toInt)
}
case t: TypeHybrid => {
val ii = index.toInt
if (ii < 0 || ii > t.fieldTys.length) {
throw new UvmHailParsingException(inCtx(ctx, "Index out of bound. Hybrid %s has %d fields. Found index: %d".format(
t, t.fieldTys.length, ii)))
}
mc.getFieldIRef(base, index.toInt)
}
case t: AbstractSeqType => {
val ii = index.toLong
if (ii < 0 || ii >= t.len) {
throw new UvmHailParsingException(inCtx(ctx, "Index out of bound. Sequence type %s has %d elements. Found index: %d".format(
t, t.len, ii)))
}
val hII = mc.handleFromInt(ii, 64)
val nc = mc.getElemIRef(base, hII)
mc.deleteValue(hII)
nc
}
}
}
case Some(l) => { // in the var-part of a hybrid
val ii = index.toLong
if (ii < 0 || ii >= l) {
throw new UvmHailParsingException(inCtx(ctx, "Index out of bound. Hybrid %s has %d actual var-part elements. Found index: %d".format(
base.ty, l, ii)))
}
val hII = mc.handleFromInt(ii, 64)
val nc = mc.shiftIRef(base, hII)
mc.deleteValue(hII)
nc
}
}
}
def assign(lv: LValue, rv: RValueContext): Unit = {
???
}
def evalIntExpr(ie: IntExprContext): BigInt = {
ie match {
case i: IntLitContext => IntLiteralToBigInt(i.intLiteral())
case i: IntGlobalContext => resConstInt(i)
}
}
val IntRe = """([+-]?)(0x|0|)([0-9a-fA-F]*)""".r
implicit def IntLiteralToBigInt(il: IntLiteralContext): BigInt = {
val txt = il.getText()
txt match {
case IntRe(sign, prefix, nums) => {
val neg = sign match {
case "+" => false
case "-" => true
case "" => false
}
val abs = prefix match {
case "0x" => BigInt(nums, 16)
case "0" => if (nums == "") BigInt(0) else BigInt(nums, 8)
case "" => BigInt(nums, 10)
}
return if (neg) -abs else abs
}
}
}
implicit def floatLiteralToFloat(fl: FloatLiteralContext): Float = fl match {
case num: FloatNumberContext => num.FP_NUM.getText.toFloat
case fi: FloatInfContext => {
if (fi.getText.startsWith("-"))
java.lang.Float.NEGATIVE_INFINITY
else java.lang.Float.POSITIVE_INFINITY
}
case _: FloatNanContext => java.lang.Float.NaN
case bits: FloatBitsContext => java.lang.Float.intBitsToFloat(bits.intLiteral().intValue())
}
implicit def doubleLiteralToDouble(dl: DoubleLiteralContext): Double = dl match {
case num: DoubleNumberContext => num.FP_NUM.getText.toDouble
case fi: DoubleInfContext => {
if (fi.getText.startsWith("-"))
java.lang.Double.NEGATIVE_INFINITY
else java.lang.Double.POSITIVE_INFINITY
}
case _: DoubleNanContext => java.lang.Double.NaN
case bits: DoubleBitsContext => java.lang.Double.longBitsToDouble(bits.intLiteral().longValue())
}
}
\ No newline at end of file
package uvm.utils
import org.antlr.v4.runtime.RecognitionException
import scala.collection.mutable.ArrayBuffer
import org.antlr.v4.runtime.Recognizer
import org.antlr.v4.runtime.BaseErrorListener
import org.antlr.v4.runtime.tree.TerminalNode
import org.antlr.v4.runtime.ParserRuleContext
import org.antlr.v4.runtime.Token
object AntlrHelpers {
class AccumulativeAntlrErrorListener(source: String) extends BaseErrorListener {
val buf = new ArrayBuffer[String]()
var hasError = false
lazy val sourceLines = ArrayBuffer(source.lines.toSeq: _*)
override def syntaxError(recognizer: Recognizer[_, _], offendingSymbol: Object,
line: Int, charPositionInLine: Int, msg: String, e: RecognitionException): Unit = {
val theLine = sourceLines(line - 1)
val marker = " " * charPositionInLine + "^"
buf += "line %d:%d %s\n%s\n%s".format(line, charPositionInLine, msg, theLine, marker)
hasError = true
}
def getMessages(): String = buf.mkString("\n")
}
def inCtx(ctx: ParserRuleContext, s: String): String = nearTok(ctx.getStart, s)
def inCtx(ctx: TerminalNode, s: String): String = nearTok(ctx.getSymbol, s)
def nearTok(tok: Token, s: String): String = {
val line = tok.getLine()
val column = tok.getCharPositionInLine()
val near = tok.getText()
return "At %d:%d near '%s': %s".format(line, column, near, s)
}
}
\ No newline at end of file
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment