1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
|
From abcebfe9714644d4e259e53b10e0e9417b5b864f Mon Sep 17 00:00:00 2001
From: Jerome Flesch <jflesch@openpaper.work>
Date: Sun, 21 Apr 2024 13:31:03 +0200
Subject: [PATCH] backend/guesswork/labels/sklearn: fix use of
scipy.sparse.hstack() + numpy.zeros()
Closes #1111
---
.../paperwork_backend/guesswork/label/sklearn/__init__.py | 5 +++--
paperwork-backend/src/paperwork_backend/model/fake.py | 6 ++++++
2 files changed, 9 insertions(+), 2 deletions(-)
diff --git a/paperwork-backend/src/paperwork_backend/guesswork/label/sklearn/__init__.py b/paperwork-backend/src/paperwork_backend/guesswork/label/sklearn/__init__.py
index b2af4350..8633211f 100644
--- a/paperwork-backend/src/paperwork_backend/guesswork/label/sklearn/__init__.py
+++ b/paperwork-backend/src/paperwork_backend/guesswork/label/sklearn/__init__.py
@@ -191,7 +191,8 @@ class UpdatableVectorizer(object):
)
if required_padding > 0:
doc_vector = numpy.hstack([
- doc_vector, numpy.zeros((required_padding,))
+ doc_vector,
+ numpy.zeros((required_padding,))
])
if sum_features is None:
sum_features = doc_vector
@@ -339,7 +340,7 @@ class Corpus(object):
if required_padding > 0:
doc_vector = scipy.sparse.hstack([
scipy.sparse.csr_matrix(doc_vector),
- numpy.zeros((required_padding,))
+ numpy.zeros((1, required_padding))
])
else:
doc_vector = scipy.sparse.csr_matrix(doc_vector)
diff --git a/paperwork-backend/src/paperwork_backend/model/fake.py b/paperwork-backend/src/paperwork_backend/model/fake.py
index 29beae97..f06fe18e 100644
--- a/paperwork-backend/src/paperwork_backend/model/fake.py
+++ b/paperwork-backend/src/paperwork_backend/model/fake.py
@@ -125,6 +125,12 @@ class Plugin(openpaperwork_core.PluginBase):
if doc['url'] == doc_url:
out.update(doc['labels'])
+ def doc_has_labels_by_url(self, doc_url):
+ for doc in self.docs:
+ if doc['url'] == doc_url:
+ return True if len(doc["labels"]) > 0 else None
+ return None
+
def doc_add_label_by_url(self, doc_url, label, color=None):
if color is None:
all_labels = set()
--
GitLab
|