
[MediaWiki-commits] [Gerrit] search/MjoLniR[master]: add python interface to scala dbn
EBernhardson has uploaded a new change for review. ( https://gerrit.wikimedia.org/r/406069 )

Change subject: add python interface to scala dbn
......................................................................

add python interface to scala dbn

It turns out only the driver has a py4j connection to the jvm;
executors talk to spark directly through sockets. To use jvm
implementations in the executors we need to trigger that work from
the jvm side. Added an implementation and some basic tests.
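
For context, the driver-side py4j pattern described above looks roughly like the
following sketch. It mirrors the new mjolnir/dbn.py hunk further down; the
function name train_via_jvm and the surrounding setup are illustrative, not part
of the patch.

import pyspark.sql


def train_via_jvm(df, dbn_config, num_partitions=200):
    # py4j (df._sc._jvm) only exists in the driver process, so rather than
    # calling JVM code from a mapPartitions closure on the executors, hand
    # the whole DataFrame over to a JVM entry point from the driver.
    jvm = df._sc._jvm
    # The JVM side expects a scala Map[String, String]
    j_config = jvm.PythonUtils.toScalaMap(
        {str(k): str(v) for k, v in dbn_config.items()})
    # DBN.train runs its own mapPartitions on the executors, entirely in the JVM
    j_df = jvm.org.wikimedia.search.mjolnir.DBN.train(df._jdf, j_config, num_partitions)
    # Wrap the returned java DataFrame back into a python DataFrame
    return pyspark.sql.DataFrame(j_df, df.sql_ctx)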

Change-Id: Iee7f79662e89bcf64cdb447aac0df5b68ee1170c
---
M jvm/src/main/scala/org/wikimedia/search/mjolnir/DBN.scala
M jvm/src/main/scala/org/wikimedia/search/mjolnir/DataWriter.scala
M jvm/src/test/scala/org/wikimedia/search/mjolnir/DBNSuite.scala
M mjolnir/dbn.py
M mjolnir/test/conftest.py
M mjolnir/test/training/test_xgboost.py
M mjolnir/utilities/data_pipeline.py
M setup.py
8 files changed, 164 insertions(+), 197 deletions(-)


git pull ssh://gerrit.wikimedia.org:29418/search/MjoLniR refs/changes/69/406069/1

diff --git a/jvm/src/main/scala/org/wikimedia/search/mjolnir/DBN.scala b/jvm/src/main/scala/org/wikimedia/search/mjolnir/DBN.scala
index faac7dc..d051c7d 100644
--- a/jvm/src/main/scala/org/wikimedia/search/mjolnir/DBN.scala
+++ b/jvm/src/main/scala/org/wikimedia/search/mjolnir/DBN.scala
@@ -13,7 +13,12 @@
* implementation was ported from python clickmodels by Aleksandr Chuklin and the
* notes on math were added in an attempt to understand why the implementation works.
*/
+import org.apache.spark.rdd.RDD
+
import scala.collection.mutable
+import org.apache.spark.sql.{DataFrame, Row, functions => F}
+import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema
+import org.apache.spark.sql.{types => T}
import org.json4s.{JArray, JBool, JString}
import org.json4s.jackson.JsonMethods

@@ -29,7 +34,7 @@
def urlToId(queryId: Int, url: String): Int = {
val urlToIdMap = queryIdToUrlToIdMap.getOrElseUpdate(queryId, { mutable.Map() })
urlToIdMap.getOrElseUpdate(url, {
- var nextUrlId = queryIdToNextUrlId.getOrElse(queryId, 0)
+ val nextUrlId = queryIdToNextUrlId.getOrElse(queryId, 0)
queryIdToNextUrlId(queryId) = nextUrlId + 1
nextUrlId
})
@@ -79,7 +84,7 @@
c
}

- val hasClicks = allClicks.take(n).exists { x => x}
+ val hasClicks = allClicks.exists { x => x }
if (urls.length < minDocsPerQuery ||
(discardNoClicks && !hasClicks)
) {
@@ -185,20 +190,14 @@
// attractiveness and satisfaction values for each position
class PositionRel(var a: Array[Double], var s: Array[Double])

-case class SessionEstimate(
- a: (Double, Double), s: (Double, Double),
- e: Array[(Double, Double)], C: Double,
- clicks: Array[Double])
-
-
class DbnModel(gamma: Double, config: Config) {
val invGamma: Double = 1D - gamma

def train(sessions: Seq[SessionItem]): Array[Array[UrlRel]] = {
// This is basically a multi-dimensional array with queryId in the first
- // dimension and urlId in the second dimension. Because queries only reference
- // a subset of the known urls we use a map at the second level instead of
- // creating the entire matrix.
+ // dimension and urlId in the second dimension. InputReader guarantees
+ // that queryId starts at 0 and is continuous, and that per-query id urlId
+ // also starts at 0 and is continuous, allowing static sized arrays to be used.
val urlRelevances: Array[Array[UrlRel]] = (0 to config.maxQueryId).map { queryId =>
(0 to config.maxUrlIds(queryId)).map { _ => new UrlRel(config.defaultRel, config.defaultRel) }.toArray
}.toArray
@@ -267,7 +266,7 @@
val queryUrlRelFrac = urlRelFractions(s.queryId)
i = 0
while (i < N) {
- var urlId = s.urlIds(i)
+ val urlId = s.urlIds(i)
// update attraction
val rel = queryUrlRelFrac(urlId)
val estA = sessionEstimate.a(i)
@@ -410,7 +409,7 @@
// (alpha, beta)
}

- var sessionEstimate = new PositionRel(new Array[Double](config.serpSize), new Array[Double](config.serpSize))
+ val sessionEstimate = new PositionRel(new Array[Double](config.serpSize), new Array[Double](config.serpSize))
// Returns
// a: P(A_i|C_i,G) - Probability of attractiveness at position i conditioned on clicked and gamma
// s: P(S_i|C_i,G) - Probability of satisfaction at position i conditioned on clicked and gamma
@@ -461,4 +460,95 @@
}
}

+private class DbnHitPage(val hitPageId: Int, val hitPosition: Double, val clicked: Boolean)

+/**
+ * Predict relevance of query/page pairs from individual user search sessions.
+ */
+object DBN {
+ // TODO: These should all be configurable? Perhaps
+ // also simplified somehow...
+ private val CLICKED = "clicked"
+ private val HITS = "hits"
+ private val HIT_PAGE_ID = "hit_page_id"
+ private val HIT_POSITION = "hit_position"
+ private val NORM_QUERY_ID = "norm_query_id"
+ private val RELEVANCE = "relevance"
+ private val SESSION_ID = "session_id"
+ private val WIKI_ID = "wikiid"
+
+ /**
+ * Given a sequence of rows representing multiple searches
+ * for a single normalized query from a single session aggregate
+ * hits into their average position and tag if it was clicked or not
+ *
+ * @param sessionHits Sequence of rows representing searches
+ * for a single normalized query and session.
+ * @return
+ */
+ private def deduplicateHits(sessionHits: Seq[Row]): (Array[String], Array[Boolean]) = {
+ val deduped = sessionHits.groupBy(_.getAs[Int](HIT_PAGE_ID))
+ .map { case (hitPageId, hits) =>
+ val hitPositions = hits.map(_.getAs[Int](HIT_POSITION))
+ val clicked = hits.exists(_.getAs[Boolean](CLICKED))
+ val avgHitPosition = hitPositions.sum.toDouble / hitPositions.length.toDouble
+ new DbnHitPage(hitPageId, avgHitPosition, clicked)
+ }
+ .toSeq.sortBy(_.hitPosition)
+ val urls = deduped.map(_.hitPageId.toString).toArray
+ val clicked = deduped.map(_.clicked).toArray
+ (urls, clicked)
+ }
+
+ val trainOutputSchema = T.StructType(
+ T.StructField(WIKI_ID, T.StringType) ::
+ T.StructField(NORM_QUERY_ID, T.LongType) ::
+ T.StructField(HIT_PAGE_ID, T.IntegerType) ::
+ T.StructField(RELEVANCE, T.DoubleType) :: Nil)
+
+ def train(df: DataFrame, dbnConfig: Map[String, String], numPartitions: Int): DataFrame = {
+ val minDocsPerQuery = dbnConfig.getOrElse("MIN_DOCS_PER_QUERY", "10").toInt
+ val serpSize = dbnConfig.getOrElse("SERP_SIZE", "10").toInt
+ val defaultRel = dbnConfig.getOrElse("DEFAULT_REL", "0.9").toFloat
+ val maxIterations = dbnConfig.getOrElse("MAX_ITERATIONS", "40").toInt
+ val gamma = dbnConfig.getOrElse("GAMMA", "0.9").toFloat
+
+ val dfGrouped = df
+ .withColumn(NORM_QUERY_ID, F.col(NORM_QUERY_ID).cast(T.LongType))
+ .withColumn(HIT_PAGE_ID, F.col(HIT_PAGE_ID).cast(T.IntegerType))
+ .withColumn(HIT_POSITION, F.col(HIT_POSITION).cast(T.IntegerType))
+ .groupBy(WIKI_ID, NORM_QUERY_ID, SESSION_ID)
+ .agg(F.collect_list(F.struct(HIT_POSITION, HIT_PAGE_ID, CLICKED)).alias(HITS))
+ .repartition(numPartitions, F.col(WIKI_ID), F.col(NORM_QUERY_ID))
+
+ val hitsIndex = dfGrouped.schema.fieldIndex(HITS)
+ val normQueryIndex = dfGrouped.schema.fieldIndex(NORM_QUERY_ID)
+ val wikiidIndex = dfGrouped.schema.fieldIndex(WIKI_ID)
+
+ val rdd: RDD[Row] = dfGrouped
+ .rdd.mapPartitions { rows: Iterator[Row] =>
+ val reader = new InputReader(minDocsPerQuery, serpSize, discardNoClicks = true)
+ val items = rows.flatMap { row =>
+ // Sorts lowest to highest
+ val (urls, clicked) = deduplicateHits(row.getSeq[Row](hitsIndex))
+ val query = row.getLong(normQueryIndex).toString
+ val region = row.getString(wikiidIndex)
+ reader.makeSessionItem(query, region, urls, clicked)
+ }.toSeq
+ if (items.isEmpty) {
+ Iterator()
+ } else {
+ // When we get a lazy seq from the iterator, ensure it's materialized
+ // before creating config with the mutable state.
+ items.length
+ val config = reader.config(defaultRel, maxIterations)
+ val model = new DbnModel(gamma, config)
+ reader.toRelevances(model.train(items)).map { rel =>
+ new GenericRowWithSchema(Array(rel.region, rel.query.toLong, rel.url.toInt, rel.relevance), trainOutputSchema)
+ }.toIterator
+ }
+ }
+
+ df.sqlContext.createDataFrame(rdd, trainOutputSchema)
+ }
+}
diff --git a/jvm/src/main/scala/org/wikimedia/search/mjolnir/DataWriter.scala b/jvm/src/main/scala/org/wikimedia/search/mjolnir/DataWriter.scala
index 01fcf6d..4af170f 100644
--- a/jvm/src/main/scala/org/wikimedia/search/mjolnir/DataWriter.scala
+++ b/jvm/src/main/scala/org/wikimedia/search/mjolnir/DataWriter.scala
@@ -43,9 +43,9 @@
// may contain gigabytes of data so this should do as little work
// as possible per-row.
private def writeOneFold(
- pathFormatter: (String, Int) => HDFSPath,
- config: Array[String]
- )(partitionId: Int, rows: Iterator[OutputRow]): Iterator[Map[String, String]] = {
+ pathFormatter: (String, Int) => HDFSPath,
+ config: Array[String]
+ )(partitionId: Int, rows: Iterator[OutputRow]): Iterator[Map[String, String]] = {
// .toSet.toVector gives us a unique list, but feels like hax
val paths = config.toSet.toVector.map { name: String =>
name -> pathFormatter(name, partitionId)
diff --git a/jvm/src/test/scala/org/wikimedia/search/mjolnir/DBNSuite.scala b/jvm/src/test/scala/org/wikimedia/search/mjolnir/DBNSuite.scala
index 2a6f7b7..7a4e889 100644
--- a/jvm/src/test/scala/org/wikimedia/search/mjolnir/DBNSuite.scala
+++ b/jvm/src/test/scala/org/wikimedia/search/mjolnir/DBNSuite.scala
@@ -1,11 +1,12 @@
package org.wikimedia.search.mjolnir

-import org.scalatest.FunSuite
+import org.apache.spark.sql.{Row, types => T}
+import org.apache.spark.sql.catalyst.expressions.GenericRow

import scala.io.Source
import scala.util.Random

-class DBNSuite extends FunSuite {
+class DBNSuite extends SharedSparkContext {
test("create session items") {
val ir = new InputReader(1, 20, true)
val item = ir.makeSessionItem(
@@ -207,4 +208,48 @@
//}
//writer.close()
}
+
+ private val nextSessionId: () => String = {
+ var current: Int = 0;
+ { () =>
+ current += 1
+ current.toString
+ }
+ }
+
+ private def makeSession(nHits: Integer): Seq[Row] = {
+ val sessionId = nextSessionId()
+ val nQueries = Random.nextInt(2) + 1
+ (0 until nQueries).flatMap { _ =>
+ val normQueryId = Random.nextLong() % 10
+ (0 until nHits).map { k =>
+ val pageId = Random.nextInt(100)
+ val clicked = Random.nextFloat() * (k + 1) < 0.5
+ new GenericRow(Array(
+ "testwiki", normQueryId, sessionId, k, pageId, clicked
+ ))
+ }
+ }
+ }
+
+ private val schema = T.StructType(
+ T.StructField("wikiid", T.StringType) ::
+ T.StructField("norm_query_id", T.LongType) ::
+ T.StructField("session_id", T.StringType) ::
+ T.StructField("hit_position", T.IntegerType) ::
+ T.StructField("hit_page_id", T.IntegerType) ::
+ T.StructField("clicked", T.BooleanType) ::
+ Nil)
+
+ test("train from a dataframe should not fail on simple query") {
+ val rows = makeSession(20) ++ makeSession(20)
+ val df = spark.createDataFrame(spark.sparkContext.parallelize(rows), schema)
+ DBN.train(df, Map(), 1).collect()
+ }
+
+ test("empty partitions should not fail") {
+ val rows = makeSession(20)
+ val df = spark.createDataFrame(spark.sparkContext.parallelize(rows), schema)
+ DBN.train(df, Map(), 200).collect()
+ }
}
diff --git a/mjolnir/dbn.py b/mjolnir/dbn.py
index 8c4ab76..a225111 100644
--- a/mjolnir/dbn.py
+++ b/mjolnir/dbn.py
@@ -1,127 +1,9 @@
"""
-Implements training a Dynamic Bayesian Network, using the clickmodels library,
-within spark
+Generate relevance probabilities from user search sessions.
"""

from __future__ import absolute_import
-from clickmodels.inference import DbnModel
-from clickmodels.input_reader import InputReader
-import json
import pyspark.sql
-from pyspark.sql import functions as F
-from pyspark.sql import types as T
-import mjolnir.spark
-
-
-def _deduplicate_hits(session_hits):
- """Deduplicate multiple views of a hit by a single session.
-
- A single session may have seen the same result list multiple times, for
- example by clicking a link, then clicking back and clicking a second link.
- Normalize that data together into a single record per hit_page_id even if
- it was displayed to a session multiple times.
-
- Parameters
- ----------
- session_hits : list
- A list of hits seen by a single session.
-
- Returns
- -------
- list
- List of hits shown to a session de-duplicated to contain only one entry
- per hit_page_id.
- """
- by_hit_page_id = {}
- for hit in session_hits:
- if hit.hit_page_id in by_hit_page_id:
- by_hit_page_id[hit.hit_page_id].append(hit)
- else:
- by_hit_page_id[hit.hit_page_id] = [hit]
-
- deduped = []
- for hit_page_id, hits in by_hit_page_id.iteritems():
- hit_positions = []
- clicked = False
- for hit in hits:
- hit_positions.append(hit.hit_position)
- clicked |= bool(hit.clicked)
- deduped.append(pyspark.sql.Row(
- hit_page_id=hit_page_id,
- hit_position=sum(hit_positions) / float(len(hit_positions)),
- clicked=clicked))
- return deduped
-
-
-def _gen_dbn_input(iterator):
- """Converts an iterator over spark rows into the DBN input format.
-
- It is perhaps undesirable that we serialize into a string with json so
- InputReader can deserialize, but it is not generic enough to avoid this
- step.
-
- Parameters
- ----------
- iterator : ???
- iterator over pyspark.sql.Row. Each row must have a wikiid,
- norm_query_id, and list of hits each containing hit_position,
- hit_page_id and clicked.
-
- Yields
- -------
- string
- Line for a single item of the input iterator formatted for use
- by clickmodels InputReader.
- """
- for row in iterator:
- results = []
- clicks = []
- deduplicated = _deduplicate_hits(row.hits)
- deduplicated.sort(key=lambda hit: hit.hit_position)
- for hit in deduplicated:
- results.append(str(hit.hit_page_id))
- clicks.append(hit.clicked)
- yield '\t'.join([
- '0', # unused identifier
- str(row.norm_query_id), # group the session belongs to
- row.wikiid, # region
- '0', # intent weight
- json.dumps(results), # hits displayed in session
- json.dumps([False] * len(results)), # layout (unused)
- json.dumps(clicks) # Was result clicked
- ])
-
-
-def _extract_labels_from_dbn(model, reader):
- """Extracts all learned labels from the model.
-
- Parameters
- ----------
- model : clickmodels.inference.DbnModel
- A trained DBN model
- reader : clickmodels.input_reader.InputReader
- Reader that was used to build the list of SessionItem's model was
- trained with.
-
- Returns
- -------
- list of tuples
- List of four value tuples each containing wikiid, norm_query_id,
- hit_page_id and relevance.
- """
- # reader converted all the page ids into an internal id, flip the map so we
- # can change them back. Not the most memory efficient, but it will do.
- uid_to_url = {uid: url for url, uid in reader.url_to_id.iteritems()}
- rows = []
- for (norm_query_id, wikiid), qid in reader.query_to_id.iteritems():
- # clickmodels required the group key to be a string, convert back
- # to an int to match input data
- norm_query_id = int(norm_query_id)
- for uid, data in model.urlRelevances[False][qid].iteritems():
- relevance = data['a'] * data['s']
- hit_page_id = int(uid_to_url[uid])
- rows.append((wikiid, norm_query_id, hit_page_id, relevance))
- return rows


def train(df, dbn_config, num_partitions=200):
@@ -137,8 +19,8 @@
User click logs with columns wikiid, norm_query_id, session_id,
hit_page_id, hit_position, clicked.
dbn_config : dict
- Configuration needed by the DBN. See clickmodels documentation for more
- information.
+ Configuration needed by the DBN. See scala implementation docs
+ for more information.
num_partitions : int
The number of partitions to split input data into for training.
Training will load the entire partition into python to feed into the
@@ -150,57 +32,10 @@
spark.sql.DataFrame
DataFrame with columns wikiid, norm_query_id, hit_page_id, relevance.
"""
- mjolnir.spark.assert_columns(df, ['wikiid', 'norm_query_id', 'session_id',
- 'hit_page_id', 'hit_position', 'clicked'])

- def train_partition(iterator):
- """Learn the relevance labels for a single DataFrame partition.
-
- Before applying to a partition ensure that sessions for queries are not
- split between multiple partitions.
-
- Parameters
- ----------
- iterator : iterator over pyspark.sql.Row's.
-
- Returns
- -------
- list of tuples
- List of (wikiid, norm_query_id, hit_page_id, relevance) tuples.
- """
- reader = InputReader(dbn_config['MIN_DOCS_PER_QUERY'],
- dbn_config['MAX_DOCS_PER_QUERY'],
- False,
- dbn_config['SERP_SIZE'],
- False,
- discard_no_clicks=True)
- sessions = reader(_gen_dbn_input(iterator))
- dbn_config['MAX_QUERY_ID'] = reader.current_query_id + 1
- model = DbnModel((0.9, 0.9, 0.9, 0.9), config=dbn_config)
- model.train(sessions)
- return _extract_labels_from_dbn(model, reader)
-
- rdd_rel = (
- df
- # group and collect up the hits for individual (wikiid, norm_query_id,
- # session_id) tuples to match how the dbn expects to receive data.
- .groupby('wikiid', 'norm_query_id', 'session_id')
- .agg(F.collect_list(F.struct('hit_position', 'hit_page_id', 'clicked')).alias('hits'))
- # Partition into small batches ensuring that all matching (wikiid,
- # norm_query_id) rows end up on the same partition.
- # TODO: The above groupby and this repartition both cause a shuffle, is
- # it possible to make that a single shuffle? Could push the final level
- # of grouping into python, but that could just as well end up worse?
- .repartition(num_partitions, 'wikiid', 'norm_query_id')
- # Run each partition through the DBN to generate relevance scores.
- .rdd.mapPartitions(train_partition))
-
- # Using toDF() is very slow as it has to run some of the partitions to check their
- # types, and then run all the partitions later to get the actual data. To prevent
- # running twice specify the schema we expect.
- return df.sql_ctx.createDataFrame(rdd_rel, T.StructType([
- T.StructField('wikiid', T.StringType(), False),
- T.StructField('norm_query_id', T.LongType(), False),
- T.StructField('hit_page_id', T.LongType(), False),
- T.StructField('relevance', T.DoubleType(), False)
- ]))
+ jvm = df._sc._jvm
+ # jvm side expects Map[String, String]
+ j_config = jvm.PythonUtils.toScalaMap({str(k): str(v) for k, v in dbn_config.items()})
+ assert j_config.size() == len(dbn_config)
+ j_df = jvm.org.wikimedia.search.mjolnir.DBN.train(df._jdf, j_config, num_partitions)
+ return pyspark.sql.DataFrame(j_df, df.sql_ctx)
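
For reference, a usage sketch of the thinner python wrapper above. The column
names and config keys come from the patch; the toy click data, the lowered
MIN_DOCS_PER_QUERY, and the spark session setup are illustrative assumptions,
and the mjolnir jar (org.wikimedia.search:mjolnir:0.4-SNAPSHOT) must be on the
spark classpath for the py4j call to resolve.

from pyspark.sql import SparkSession

import mjolnir.dbn

spark = SparkSession.builder.getOrCreate()

# Toy click log with the columns train() expects; real input comes from
# the data_pipeline utility.
df_clicks = spark.createDataFrame([
    ('testwiki', 1, 'session-a', 101, 1, True),
    ('testwiki', 1, 'session-a', 102, 2, False),
    ('testwiki', 1, 'session-b', 101, 1, False),
    ('testwiki', 1, 'session-b', 102, 2, True),
], ['wikiid', 'norm_query_id', 'session_id', 'hit_page_id', 'hit_position', 'clicked'])

df_rel = mjolnir.dbn.train(df_clicks, dbn_config={
    'MIN_DOCS_PER_QUERY': 2,  # lowered so the toy sessions are not discarded
    'SERP_SIZE': 10,
    'MAX_ITERATIONS': 40,
    'DEFAULT_REL': 0.9,
    'GAMMA': 0.9,
}, num_partitions=1)
# df_rel has columns wikiid, norm_query_id, hit_page_id, relevance
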
diff --git a/mjolnir/test/conftest.py b/mjolnir/test/conftest.py
index 0ab729c..efc8441 100644
--- a/mjolnir/test/conftest.py
+++ b/mjolnir/test/conftest.py
@@ -71,7 +71,7 @@
# Maven coordinates of jvm dependencies
.set('spark.jars.packages', ','.join([
'ml.dmlc:xgboost4j-spark:0.8-wmf-1',
- 'org.wikimedia.search:mjolnir:0.3',
+ 'org.wikimedia.search:mjolnir:0.4-SNAPSHOT',
'org.apache.spark:spark-streaming-kafka-0-8_2.11:2.1.0']))
# By default spark will shuffle to 200 partitions, which is
# way too many for our small test cases. This cuts execution
diff --git a/mjolnir/test/training/test_xgboost.py b/mjolnir/test/training/test_xgboost.py
index e49e496..100a3e8 100644
--- a/mjolnir/test/training/test_xgboost.py
+++ b/mjolnir/test/training/test_xgboost.py
@@ -1,6 +1,7 @@
from __future__ import absolute_import
import mjolnir.training.xgboost
from pyspark.ml.linalg import Vectors
+import pyspark.sql
import pytest


diff --git a/mjolnir/utilities/data_pipeline.py b/mjolnir/utilities/data_pipeline.py
index 21f0d7e..ae63de3 100644
--- a/mjolnir/utilities/data_pipeline.py
+++ b/mjolnir/utilities/data_pipeline.py
@@ -90,12 +90,9 @@
df_rel = (
mjolnir.dbn.train(df_sampled, num_partitions=dbn_partitions, dbn_config={
'MAX_ITERATIONS': 40,
- 'DEBUG': False,
- 'PRETTY_LOG': True,
'MIN_DOCS_PER_QUERY': 10,
'MAX_DOCS_PER_QUERY': 20,
'SERP_SIZE': 20,
- 'QUERY_INDEPENDENT_PAGER': False,
'DEFAULT_REL': 0.5})
# naive conversion of relevance % into a label
.withColumn('label', (F.col('relevance') * 10).cast('int')))
diff --git a/setup.py b/setup.py
index 9bd27e4..878a24d 100644
--- a/setup.py
+++ b/setup.py
@@ -4,7 +4,6 @@

requirements = [
# mjolnir requirements
- 'clickmodels',
'requests',
'kafka',
'pyyaml',

--
To view, visit https://gerrit.wikimedia.org/r/406069
To unsubscribe, visit https://gerrit.wikimedia.org/r/settings

Gerrit-MessageType: newchange
Gerrit-Change-Id: Iee7f79662e89bcf64cdb447aac0df5b68ee1170c
Gerrit-PatchSet: 1
Gerrit-Project: search/MjoLniR
Gerrit-Branch: master
Gerrit-Owner: EBernhardson <ebernhardson@wikimedia.org>

_______________________________________________
MediaWiki-commits mailing list
MediaWiki-commits@lists.wikimedia.org
https://lists.wikimedia.org/mailman/listinfo/mediawiki-commits