"""
Query wrappers and search utilities.
"""
from future_builtins import filter, map
import itertools
import bisect
import heapq
import lucene
class ArrayList(lucene.ArrayList):
    "Java ArrayList constructed from any python iterable of values."
    def __init__(self, values=()):
        lucene.ArrayList.__init__(self)
        # java ArrayList has no bulk constructor for python iterables
        for value in values:
            self.add(value)
class Query(object):
    """Inherited lucene Query, with dynamic base class acquisition.

    Uses class methods and operator overloading for convenient query construction.
    """
    def __new__(cls, base, *args):
        # Generate a subclass on the fly which inherits from both this wrapper
        # and the given lucene query class, so the instance is a real lucene
        # query augmented with the python conveniences defined here.
        return base.__new__(type(base.__name__, (cls, base), {}))
    def __init__(self, base, *args):
        # Delegate construction to the lucene base with the remaining args.
        base.__init__(self, *args)
    def filter(self, cache=True):
        "Return lucene CachingWrapperFilter, optionally just QueryWrapperFilter."
        filter = lucene.QueryWrapperFilter(self)
        return lucene.CachingWrapperFilter(filter) if cache else filter
    def terms(self):
        "Generate set of query term items as (field, text) pairs."
        terms = lucene.HashSet().of_(lucene.Term)
        self.extractTerms(terms)
        for term in terms:
            yield term.field(), term.text()
    @classmethod
    def term(cls, name, value):
        "Return lucene TermQuery."
        return cls(lucene.TermQuery, lucene.Term(name, value))
    @classmethod
    def boolean(cls, occur, *queries, **terms):
        # Shared builder for :meth:`any` and :meth:`all`.  Each keyword maps a
        # field name to a value or iterable of values; each value contributes
        # one TermQuery clause with the given occur flag.
        self = BooleanQuery(lucene.BooleanQuery)
        for query in queries:
            self.add(query, occur)
        for name, values in terms.items():
            for value in ([values] if isinstance(values, basestring) else values):
                self.add(cls.term(name, value), occur)
        return self
    @classmethod
    def any(cls, *queries, **terms):
        "Return `BooleanQuery`_ (OR) from queries and terms."
        return cls.boolean(lucene.BooleanClause.Occur.SHOULD, *queries, **terms)
    @classmethod
    def all(cls, *queries, **terms):
        "Return `BooleanQuery`_ (AND) from queries and terms."
        return cls.boolean(lucene.BooleanClause.Occur.MUST, *queries, **terms)
    @classmethod
    def disjunct(cls, multiplier, *queries, **terms):
        "Return lucene DisjunctionMaxQuery from queries and terms."
        self = cls(lucene.DisjunctionMaxQuery, ArrayList(queries), multiplier)
        for name, values in terms.items():
            for value in ([values] if isinstance(values, basestring) else values):
                self.add(cls.term(name, value))
        return self
    @classmethod
    def span(cls, *term):
        "Return `SpanQuery`_ from term name and value or a MultiTermQuery."
        if len(term) <= 1:
            # a single argument is an existing MultiTermQuery to be wrapped
            return SpanQuery(lucene.SpanMultiTermQueryWrapper, *term)
        return SpanQuery(lucene.SpanTermQuery, lucene.Term(*term))
    @classmethod
    def near(cls, name, *values, **kwargs):
        """Return :meth:`SpanNearQuery <SpanQuery.near>` from terms.

        Term values which supply another field name will be masked.
        """
        spans = (cls.span(name, value) if isinstance(value, basestring) else cls.span(*value).mask(name) for value in values)
        return SpanQuery.near(*spans, **kwargs)
    @classmethod
    def prefix(cls, name, value):
        "Return lucene PrefixQuery."
        return cls(lucene.PrefixQuery, lucene.Term(name, value))
    @classmethod
    def range(cls, name, start, stop, lower=True, upper=False):
        "Return lucene RangeQuery, by default with a half-open interval."
        return cls(lucene.TermRangeQuery, name, start, stop, lower, upper)
    @classmethod
    def phrase(cls, name, *values):
        "Return lucene PhraseQuery. None may be used as a placeholder."
        self = cls(lucene.PhraseQuery)
        for index, value in enumerate(values):
            # None skips a position, leaving a positional gap in the phrase
            if value is not None:
                self.add(lucene.Term(name, value), index)
        return self
    @classmethod
    def multiphrase(cls, name, *values):
        "Return lucene MultiPhraseQuery. None may be used as a placeholder."
        self = cls(lucene.MultiPhraseQuery)
        for index, words in enumerate(values):
            # a bare string counts as a single-word alternative set
            if isinstance(words, basestring):
                words = [words]
            if words is not None:
                self.add([lucene.Term(name, word) for word in words], index)
        return self
    @classmethod
    def wildcard(cls, name, value):
        "Return lucene WildcardQuery."
        return cls(lucene.WildcardQuery, lucene.Term(name, value))
    @classmethod
    def fuzzy(cls, name, value, minimumSimilarity=0.5, prefixLength=0):
        "Return lucene FuzzyQuery."
        return cls(lucene.FuzzyQuery, lucene.Term(name, value), minimumSimilarity, prefixLength)
    def __pos__(self):
        # +query: single required (MUST) clause in a new BooleanQuery
        return Query.all(self)
    def __neg__(self):
        # -query: single prohibited (MUST_NOT) clause in a new BooleanQuery
        return Query.boolean(lucene.BooleanClause.Occur.MUST_NOT, self)
    def __and__(self, other):
        return Query.all(self, other)
    def __rand__(self, other):
        return Query.all(other, self)
    def __or__(self, other):
        return Query.any(self, other)
    def __ror__(self, other):
        return Query.any(other, self)
    def __sub__(self, other):
        # query - other: optional self with prohibited other
        return Query.any(self).__isub__(other)
    def __rsub__(self, other):
        return Query.any(other).__isub__(self)
class BooleanQuery(Query):
    "Inherited lucene BooleanQuery with sequence interface to clauses."
    def __len__(self):
        return len(self.getClauses())
    def __iter__(self):
        for clause in self.getClauses():
            yield clause
    def __getitem__(self, index):
        clauses = self.getClauses()
        return clauses[index]
    def __iand__(self, other):
        "Add a required (MUST) clause in-place."
        self.add(other, lucene.BooleanClause.Occur.MUST)
        return self
    def __ior__(self, other):
        "Add an optional (SHOULD) clause in-place."
        self.add(other, lucene.BooleanClause.Occur.SHOULD)
        return self
    def __isub__(self, other):
        "Add a prohibited (MUST_NOT) clause in-place."
        self.add(other, lucene.BooleanClause.Occur.MUST_NOT)
        return self
class SpanQuery(Query):
    "Inherited lucene SpanQuery with additional span constructors."
    def __getitem__(self, slc):
        # Slicing restricts matches by position: [:stop] builds a
        # SpanFirstQuery, [start:stop] a SpanPositionRangeQuery.
        start, stop, step = slc.indices(lucene.Integer.MAX_VALUE)
        assert step == 1, 'slice step is not supported'
        if start == 0:
            return SpanQuery(lucene.SpanFirstQuery, self, stop)
        return SpanQuery(lucene.SpanPositionRangeQuery, self, start, stop)
    def __sub__(self, other):
        # span - other: SpanNotQuery, matches of self excluding other
        return SpanQuery(lucene.SpanNotQuery, self, other)
    def __or__(*spans):
        # no explicit self: all operands are collected into one SpanOrQuery
        return SpanQuery(lucene.SpanOrQuery, spans)
    def near(*spans, **kwargs):
        """Return lucene SpanNearQuery from SpanQueries.

        :param slop: default 0
        :param inOrder: default True
        :param collectPayloads: default True
        """
        # map pairs each keyword name with its default for kwargs.get
        args = map(kwargs.get, ('slop', 'inOrder', 'collectPayloads'), (0, True, True))
        return SpanQuery(lucene.SpanNearQuery, spans, *args)
    def mask(self, name):
        "Return lucene FieldMaskingSpanQuery, which allows combining span queries from different fields."
        return SpanQuery(lucene.FieldMaskingSpanQuery, self, name)
    def payload(self, *values):
        "Return lucene SpanPayloadCheckQuery from payload values."
        # near queries require the specialized payload check class
        base = lucene.SpanNearPayloadCheckQuery if lucene.SpanNearQuery.instance_(self) else lucene.SpanPayloadCheckQuery
        return SpanQuery(base, self, ArrayList(map(lucene.JArray_byte, values)))
class Collector(lucene.PythonCollector):
    "Collect all ids and scores efficiently."
    def __init__(self):
        lucene.PythonCollector.__init__(self)
        self.scores = {}
    def collect(self, id, score):
        # offset segment-local doc ids by the current reader base
        self.scores[id + self.base] = score
    def setNextReader(self, reader, base):
        self.base = base
    def acceptsDocsOutOfOrder(self):
        return True
    def sorted(self, key=None, reverse=False):
        "Return ordered ids and scores; by default highest-scoring first."
        scores = self.scores
        if key is None:
            key, reverse = scores.__getitem__, True
        # sort ids ascending first so equal keys tie-break deterministically
        ids = sorted(scores)
        ids.sort(key=key, reverse=reverse)
        return ids, [scores[id] for id in ids]
class SortField(lucene.SortField):
    """Inherited lucene SortField used for caching FieldCache parsers.

    :param name: field name
    :param type: type object or name compatible with SortField constants
    :param parser: lucene FieldCache.Parser or callable applied to field values
    :param reverse: reverse flag used with sort
    """
    def __init__(self, name, type='string', parser=None, reverse=False):
        # normalize type to capitalized name, e.g. int -> 'Int'
        type = self.typename = getattr(type, '__name__', type).capitalize()
        if parser is None:
            # fall back to the matching SortField constant, e.g. SortField.STRING
            parser = getattr(lucene.SortField, type.upper())
        elif not lucene.FieldCache.Parser.instance_(parser):
            # wrap a plain callable in a generated parser subclass,
            # e.g. PythonIntParser with a parseInt staticmethod
            base = getattr(lucene, 'Python{0}Parser'.format(type))
            namespace = {'parse' + type: staticmethod(parser)}
            # object.__class__ is type; build and instantiate the subclass
            parser = object.__class__(base.__name__, (base,), namespace)()
        lucene.SortField.__init__(self, name, parser, reverse)
    def comparator(self, reader):
        "Return indexed values from default FieldCache using the given reader."
        method = getattr(lucene.FieldCache.DEFAULT, 'get{0}s'.format(self.typename))
        # pass the parser only if one was stored (list is empty otherwise)
        args = [self.parser] * bool(self.parser)
        readers = reader.sequentialSubReaders
        if lucene.MultiReader.instance_(reader):
            # MultiReaders nest one level deeper; flatten to leaf readers
            readers = itertools.chain.from_iterable(reader.sequentialSubReaders for reader in readers)
        arrays = [method(reader, self.field, *args) for reader in readers]
        if len(arrays) <= 1:
            return arrays[0]
        # all segment arrays share one type; concatenate into a single array
        cls, = set(map(type, arrays))
        index, result = 0, cls(sum(map(len, arrays)))
        for array in arrays:
            lucene.System.arraycopy(array, 0, result, index, len(array))
            index += len(array)
        return result
class Highlighter(lucene.Highlighter):
    """Inherited lucene Highlighter with stored analysis options.

    :param searcher: `IndexSearcher`_ used for analysis, scoring, and optionally text retrieval
    :param query: lucene Query
    :param field: field name of text
    :param terms: highlight any matching term in query regardless of position
    :param fields: highlight matching terms from any field
    :param tag: optional html tag name
    :param formatter: optional lucene Formatter
    :param encoder: optional lucene Encoder
    """
    def __init__(self, searcher, query, field, terms=False, fields=False, tag='', formatter=None, encoder=None):
        if tag:
            formatter = lucene.SimpleHTMLFormatter('<{0}>'.format(tag), '</{0}>'.format(tag))
        # QueryTermScorer matches terms anywhere; QueryScorer is positional.
        # The (reader, field) args are repeated 0 or 1 times: they are only
        # supplied when highlighting is restricted to a single field.
        scorer = (lucene.QueryTermScorer if terms else lucene.QueryScorer)(query, *(searcher.indexReader, field) * (not fields))
        # filter(None, ...) drops whichever optional args were not supplied
        lucene.Highlighter.__init__(self, *filter(None, [formatter, encoder, scorer]))
        self.searcher, self.field = searcher, field
        # selector limits stored-field retrieval to just the highlighted field
        self.selector = lucene.MapFieldSelector([field])
    def fragments(self, doc, count=1):
        """Return highlighted text fragments.

        :param doc: text string or doc id to be highlighted
        :param count: maximum number of fragments
        """
        if not isinstance(doc, basestring):
            # a non-string doc is treated as an id; fetch its stored text
            doc = self.searcher.doc(doc, self.selector)[self.field]
        return doc and list(self.getBestFragments(self.searcher.analyzer, self.field, doc, count))
# getattr fallback keeps the module importable when the lucene build
# lacks the FastVectorHighlighter contrib class.
class FastVectorHighlighter(getattr(lucene, 'FastVectorHighlighter', object)):
    """Inherited lucene FastVectorHighlighter with stored query.

    Fields must be stored and have term vectors with offsets and positions.

    :param searcher: `IndexSearcher`_ with stored term vectors
    :param query: lucene Query
    :param field: field name of text
    :param terms: highlight any matching term in query regardless of position
    :param fields: highlight matching terms from any field
    :param tag: optional html tag name
    :param fragListBuilder: optional lucene FragListBuilder
    :param fragmentsBuilder: optional lucene FragmentsBuilder
    """
    def __init__(self, searcher, query, field, terms=False, fields=False, tag='', fragListBuilder=None, fragmentsBuilder=None):
        if tag:
            fragmentsBuilder = lucene.SimpleFragmentsBuilder(['<{0}>'.format(tag)], ['</{0}>'.format(tag)])
        args = fragListBuilder or lucene.SimpleFragListBuilder(), fragmentsBuilder or lucene.SimpleFragmentsBuilder()
        # first two flags: phraseHighlight (not terms), fieldMatch (not fields)
        lucene.FastVectorHighlighter.__init__(self, not terms, not fields, *args)
        self.searcher, self.field = searcher, field
        self.query = self.getFieldQuery(query)
    def fragments(self, id, count=1, size=100):
        """Return highlighted text fragments.

        :param id: document id
        :param count: maximum number of fragments
        :param size: maximum number of characters in fragment
        """
        return list(self.getBestFragments(self.query, self.searcher.indexReader, id, self.field, size, count))
class SpellChecker(dict):
    """Correct spellings and suggest words for queries.

    The mapping itself is the vocabulary: each word maps to a (reverse) sort
    key, such as its document frequency.
    """
    def __init__(self, *args, **kwargs):
        dict.__init__(self, *args, **kwargs)
        self.words = sorted(self)
        # every distinct character across the vocabulary, in sorted order
        self.alphabet = sorted(set(itertools.chain.from_iterable(self.words)))
        if self.alphabet:
            # upper bound string: no vocabulary word can sort after prefix+suffix
            self.suffix = self.alphabet[-1] * max(len(word) for word in self.words)
        else:
            self.suffix = ''
        # every prefix (including '' and whole words) for pruning edits
        self.prefixes = set(word[:stop] for word in self.words for stop in range(len(word) + 1))
    def suggest(self, prefix, count=None):
        "Return ordered suggested words for prefix."
        left = bisect.bisect_left(self.words, prefix)
        right = bisect.bisect_right(self.words, prefix + self.suffix, left)
        matches = self.words[left:right]
        if count is not None and count < len(matches):
            # partial selection beats a full sort when only count are wanted
            return heapq.nlargest(count, matches, key=self.__getitem__)
        return sorted(matches, key=self.__getitem__, reverse=True)
    def edits(self, word, length=0):
        "Return set of potential words one edit distance away, mapped to valid prefix lengths."
        pairs = [(word[:index], word[index:]) for index in range(len(word) + 1)]
        if length:
            result = {}
        else:
            deletes = (head + tail[1:] for head, tail in pairs[:-1])
            # tail[1::-1] swaps the first two characters of the tail
            transposes = (head + tail[1::-1] + tail[2:] for head, tail in pairs[:-2])
            result = dict.fromkeys(itertools.chain(deletes, transposes), 0)
        for head, tail in pairs[length:]:
            # once the head is no longer a valid prefix, no later split can be
            if head not in self.prefixes:
                break
            for char in self.alphabet:
                prefix = head + char
                if prefix in self.prefixes:
                    # insertion (keep tail) and substitution (drop tail's head)
                    result[prefix + tail] = result[prefix + tail[1:]] = len(prefix)
        return result
    def correct(self, word):
        "Generate ordered sets of words by increasing edit distance."
        seen, candidates = set(), {word: 0}
        for _ in range(len(word)):
            yield sorted([candidate for candidate in candidates if candidate in self], key=self.__getitem__, reverse=True)
            seen.update(candidates)
            # expand each candidate from its valid prefix length onward
            candidates = dict(
                (edit, group[edit])
                for group in map(self.edits, candidates, candidates.values())
                for edit in group if edit not in seen)
class SpellParser(lucene.PythonQueryParser):
    """Inherited lucene QueryParser which corrects spelling.

    Assign a searcher attribute or override :meth:`correct` implementation.
    """
    def correct(self, term):
        "Return term with text replaced as necessary."
        field = term.field()
        # take the first suggestion, if any; otherwise keep the term as-is
        for text in self.searcher.correct(field, term.text()):
            return lucene.Term(field, text)
        return term
    def rewrite(self, query):
        "Return term or phrase query with corrected terms substituted."
        if lucene.TermQuery.instance_(query):
            term = lucene.TermQuery.cast_(query).term
            return lucene.TermQuery(self.correct(term))
        # otherwise rebuild the phrase query with each term corrected,
        # preserving the original term positions
        query = lucene.PhraseQuery.cast_(query)
        phrase = lucene.PhraseQuery()
        for position, term in zip(query.positions, query.terms):
            phrase.add(self.correct(term), position)
        return phrase
    def getFieldQuery(self, field, text, *args):
        # extra args indicate a sloppy phrase handled by getFieldQuery_slop
        query = lucene.PythonQueryParser.getFieldQuery(self, field, text, *args)
        return query if args else self.rewrite(query)
    def getFieldQuery_quoted(self, *args):
        return self.rewrite(self.getFieldQuery_quoted_super(*args))
    def getFieldQuery_slop(self, *args):
        return self.rewrite(self.getFieldQuery_slop_super(*args))