update
This commit is contained in:
parent
889e458873
commit
f2b1080f55
|
@ -0,0 +1,29 @@
|
|||
@misc{Yannic_2022,
|
||||
title = {GPT-4chan},
|
||||
url = {https://gpt-4chan.com/},
|
||||
author = {Yannic, Kilcher},
|
||||
year = 2022
|
||||
}
|
||||
|
||||
@article{papasavva2020raiders,
|
||||
title = {Raiders of the Lost Kek: 3.5 Years of Augmented 4chan Posts from the Politically Incorrect Board},
|
||||
author = {Antonis Papasavva, Savvas Zannettou, Emiliano De Cristofaro, Gianluca Stringhini, Jeremy Blackburn},
|
||||
journal = {14th International AAAI Conference On Web And Social Media (ICWSM), 2020},
|
||||
year = 2020
|
||||
}
|
||||
|
||||
@misc{mesh-transformer-jax,
|
||||
author = {Wang, Ben},
|
||||
title = {{Mesh-Transformer-JAX: Model-Parallel Implementation of Transformer Language Model with JAX}},
|
||||
howpublished = {\url{https://github.com/kingoflolz/mesh-transformer-jax}},
|
||||
year = 2021,
|
||||
month = May
|
||||
}
|
||||
|
||||
@misc{gpt-j,
|
||||
author = {Wang, Ben and Komatsuzaki, Aran},
|
||||
title = {{GPT-J-6B: A 6 Billion Parameter Autoregressive Language Model}},
|
||||
howpublished = {\url{https://github.com/kingoflolz/mesh-transformer-jax}},
|
||||
year = 2021,
|
||||
month = May
|
||||
}
|
|
@ -0,0 +1,201 @@
|
|||
Apache License
|
||||
Version 2.0, January 2004
|
||||
http://www.apache.org/licenses/
|
||||
|
||||
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
||||
|
||||
1. Definitions.
|
||||
|
||||
"License" shall mean the terms and conditions for use, reproduction,
|
||||
and distribution as defined by Sections 1 through 9 of this document.
|
||||
|
||||
"Licensor" shall mean the copyright owner or entity authorized by
|
||||
the copyright owner that is granting the License.
|
||||
|
||||
"Legal Entity" shall mean the union of the acting entity and all
|
||||
other entities that control, are controlled by, or are under common
|
||||
control with that entity. For the purposes of this definition,
|
||||
"control" means (i) the power, direct or indirect, to cause the
|
||||
direction or management of such entity, whether by contract or
|
||||
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
||||
outstanding shares, or (iii) beneficial ownership of such entity.
|
||||
|
||||
"You" (or "Your") shall mean an individual or Legal Entity
|
||||
exercising permissions granted by this License.
|
||||
|
||||
"Source" form shall mean the preferred form for making modifications,
|
||||
including but not limited to software source code, documentation
|
||||
source, and configuration files.
|
||||
|
||||
"Object" form shall mean any form resulting from mechanical
|
||||
transformation or translation of a Source form, including but
|
||||
not limited to compiled object code, generated documentation,
|
||||
and conversions to other media types.
|
||||
|
||||
"Work" shall mean the work of authorship, whether in Source or
|
||||
Object form, made available under the License, as indicated by a
|
||||
copyright notice that is included in or attached to the work
|
||||
(an example is provided in the Appendix below).
|
||||
|
||||
"Derivative Works" shall mean any work, whether in Source or Object
|
||||
form, that is based on (or derived from) the Work and for which the
|
||||
editorial revisions, annotations, elaborations, or other modifications
|
||||
represent, as a whole, an original work of authorship. For the purposes
|
||||
of this License, Derivative Works shall not include works that remain
|
||||
separable from, or merely link (or bind by name) to the interfaces of,
|
||||
the Work and Derivative Works thereof.
|
||||
|
||||
"Contribution" shall mean any work of authorship, including
|
||||
the original version of the Work and any modifications or additions
|
||||
to that Work or Derivative Works thereof, that is intentionally
|
||||
submitted to Licensor for inclusion in the Work by the copyright owner
|
||||
or by an individual or Legal Entity authorized to submit on behalf of
|
||||
the copyright owner. For the purposes of this definition, "submitted"
|
||||
means any form of electronic, verbal, or written communication sent
|
||||
to the Licensor or its representatives, including but not limited to
|
||||
communication on electronic mailing lists, source code control systems,
|
||||
and issue tracking systems that are managed by, or on behalf of, the
|
||||
Licensor for the purpose of discussing and improving the Work, but
|
||||
excluding communication that is conspicuously marked or otherwise
|
||||
designated in writing by the copyright owner as "Not a Contribution."
|
||||
|
||||
"Contributor" shall mean Licensor and any individual or Legal Entity
|
||||
on behalf of whom a Contribution has been received by Licensor and
|
||||
subsequently incorporated within the Work.
|
||||
|
||||
2. Grant of Copyright License. Subject to the terms and conditions of
|
||||
this License, each Contributor hereby grants to You a perpetual,
|
||||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||
copyright license to reproduce, prepare Derivative Works of,
|
||||
publicly display, publicly perform, sublicense, and distribute the
|
||||
Work and such Derivative Works in Source or Object form.
|
||||
|
||||
3. Grant of Patent License. Subject to the terms and conditions of
|
||||
this License, each Contributor hereby grants to You a perpetual,
|
||||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||
(except as stated in this section) patent license to make, have made,
|
||||
use, offer to sell, sell, import, and otherwise transfer the Work,
|
||||
where such license applies only to those patent claims licensable
|
||||
by such Contributor that are necessarily infringed by their
|
||||
Contribution(s) alone or by combination of their Contribution(s)
|
||||
with the Work to which such Contribution(s) was submitted. If You
|
||||
institute patent litigation against any entity (including a
|
||||
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
||||
or a Contribution incorporated within the Work constitutes direct
|
||||
or contributory patent infringement, then any patent licenses
|
||||
granted to You under this License for that Work shall terminate
|
||||
as of the date such litigation is filed.
|
||||
|
||||
4. Redistribution. You may reproduce and distribute copies of the
|
||||
Work or Derivative Works thereof in any medium, with or without
|
||||
modifications, and in Source or Object form, provided that You
|
||||
meet the following conditions:
|
||||
|
||||
(a) You must give any other recipients of the Work or
|
||||
Derivative Works a copy of this License; and
|
||||
|
||||
(b) You must cause any modified files to carry prominent notices
|
||||
stating that You changed the files; and
|
||||
|
||||
(c) You must retain, in the Source form of any Derivative Works
|
||||
that You distribute, all copyright, patent, trademark, and
|
||||
attribution notices from the Source form of the Work,
|
||||
excluding those notices that do not pertain to any part of
|
||||
the Derivative Works; and
|
||||
|
||||
(d) If the Work includes a "NOTICE" text file as part of its
|
||||
distribution, then any Derivative Works that You distribute must
|
||||
include a readable copy of the attribution notices contained
|
||||
within such NOTICE file, excluding those notices that do not
|
||||
pertain to any part of the Derivative Works, in at least one
|
||||
of the following places: within a NOTICE text file distributed
|
||||
as part of the Derivative Works; within the Source form or
|
||||
documentation, if provided along with the Derivative Works; or,
|
||||
within a display generated by the Derivative Works, if and
|
||||
wherever such third-party notices normally appear. The contents
|
||||
of the NOTICE file are for informational purposes only and
|
||||
do not modify the License. You may add Your own attribution
|
||||
notices within Derivative Works that You distribute, alongside
|
||||
or as an addendum to the NOTICE text from the Work, provided
|
||||
that such additional attribution notices cannot be construed
|
||||
as modifying the License.
|
||||
|
||||
You may add Your own copyright statement to Your modifications and
|
||||
may provide additional or different license terms and conditions
|
||||
for use, reproduction, or distribution of Your modifications, or
|
||||
for any such Derivative Works as a whole, provided Your use,
|
||||
reproduction, and distribution of the Work otherwise complies with
|
||||
the conditions stated in this License.
|
||||
|
||||
5. Submission of Contributions. Unless You explicitly state otherwise,
|
||||
any Contribution intentionally submitted for inclusion in the Work
|
||||
by You to the Licensor shall be under the terms and conditions of
|
||||
this License, without any additional terms or conditions.
|
||||
Notwithstanding the above, nothing herein shall supersede or modify
|
||||
the terms of any separate license agreement you may have executed
|
||||
with Licensor regarding such Contributions.
|
||||
|
||||
6. Trademarks. This License does not grant permission to use the trade
|
||||
names, trademarks, service marks, or product names of the Licensor,
|
||||
except as required for reasonable and customary use in describing the
|
||||
origin of the Work and reproducing the content of the NOTICE file.
|
||||
|
||||
7. Disclaimer of Warranty. Unless required by applicable law or
|
||||
agreed to in writing, Licensor provides the Work (and each
|
||||
Contributor provides its Contributions) on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
||||
implied, including, without limitation, any warranties or conditions
|
||||
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
||||
PARTICULAR PURPOSE. You are solely responsible for determining the
|
||||
appropriateness of using or redistributing the Work and assume any
|
||||
risks associated with Your exercise of permissions under this License.
|
||||
|
||||
8. Limitation of Liability. In no event and under no legal theory,
|
||||
whether in tort (including negligence), contract, or otherwise,
|
||||
unless required by applicable law (such as deliberate and grossly
|
||||
negligent acts) or agreed to in writing, shall any Contributor be
|
||||
liable to You for damages, including any direct, indirect, special,
|
||||
incidental, or consequential damages of any character arising as a
|
||||
result of this License or out of the use or inability to use the
|
||||
Work (including but not limited to damages for loss of goodwill,
|
||||
work stoppage, computer failure or malfunction, or any and all
|
||||
other commercial damages or losses), even if such Contributor
|
||||
has been advised of the possibility of such damages.
|
||||
|
||||
9. Accepting Warranty or Additional Liability. While redistributing
|
||||
the Work or Derivative Works thereof, You may choose to offer,
|
||||
and charge a fee for, acceptance of support, warranty, indemnity,
|
||||
or other liability obligations and/or rights consistent with this
|
||||
License. However, in accepting such obligations, You may act only
|
||||
on Your own behalf and on Your sole responsibility, not on behalf
|
||||
of any other Contributor, and only if You agree to indemnify,
|
||||
defend, and hold each Contributor harmless for any liability
|
||||
incurred by, or claims asserted against, such Contributor by reason
|
||||
of your accepting any such warranty or additional liability.
|
||||
|
||||
END OF TERMS AND CONDITIONS
|
||||
|
||||
APPENDIX: How to apply the Apache License to your work.
|
||||
|
||||
To apply the Apache License to your work, attach the following
|
||||
boilerplate notice, with the fields enclosed by brackets "[]"
|
||||
replaced with your own identifying information. (Don't include
|
||||
the brackets!) The text should be enclosed in the appropriate
|
||||
comment syntax for the file format. We also recommend that a
|
||||
file or class name and description of purpose be included on the
|
||||
same "printed page" as the copyright notice for easier
|
||||
identification within third-party archives.
|
||||
|
||||
Copyright [yyyy] [name of copyright owner]
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
19
README.md
19
README.md
|
@ -0,0 +1,19 @@
|
|||
# gpt model
|
||||
- the model is 23.5gb so im storing it offline, it can be found here.
|
||||
- https://archive.org/download/gpt4chan_model
|
||||
|
||||
# gpt-4chan-public
|
||||
Code for GPT-4chan
|
||||
|
||||
Note: This repository only contains helper code and small changes I made to other libraries.
|
||||
The source code to the actual model is here at [https://github.com/kingoflolz/mesh-transformer-jax/](https://github.com/kingoflolz/mesh-transformer-jax/)
|
||||
|
||||
|
||||
Data here: [https://zenodo.org/record/3606810](https://zenodo.org/record/3606810)
|
||||
|
||||
Model here: [https://huggingface.co/ykilcher/gpt-4chan](https://huggingface.co/ykilcher/gpt-4chan)
|
||||
|
||||
Website here: [https://gpt-4chan.com](https://gpt-4chan.com)
|
||||
|
||||
|
||||
Also, I will not release the bot code.
|
|
@ -0,0 +1,36 @@
|
|||
{
|
||||
"activation_function": "gelu_new",
|
||||
"architectures": [
|
||||
"GPTJForCausalLM"
|
||||
],
|
||||
"attn_pdrop": 0.0,
|
||||
"bos_token_id": 50256,
|
||||
"embd_pdrop": 0.0,
|
||||
"eos_token_id": 50256,
|
||||
"gradient_checkpointing": false,
|
||||
"initializer_range": 0.02,
|
||||
"layer_norm_epsilon": 1e-05,
|
||||
"model_type": "gptj",
|
||||
"n_embd": 4096,
|
||||
"n_head": 16,
|
||||
"n_layer": 28,
|
||||
"n_positions": 2048,
|
||||
"rotary_dim": 64,
|
||||
"summary_activation": null,
|
||||
"summary_first_dropout": 0.1,
|
||||
"summary_proj_to_labels": true,
|
||||
"summary_type": "cls_index",
|
||||
"summary_use_proj": true,
|
||||
"transformers_version": "4.10.0.dev0",
|
||||
"tokenizer_class": "GPT2Tokenizer",
|
||||
"task_specific_params": {
|
||||
"text-generation": {
|
||||
"do_sample": true,
|
||||
"temperature": 1.0,
|
||||
"max_length": 50
|
||||
}
|
||||
},
|
||||
"torch_dtype": "float16",
|
||||
"use_cache": true,
|
||||
"vocab_size": 50400
|
||||
}
|
Binary file not shown.
|
@ -0,0 +1,59 @@
|
|||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<files>
|
||||
<file name="CITATION.bib" source="original">
|
||||
<mtime>1654573420</mtime>
|
||||
<size>1004</size>
|
||||
<md5>3aefaf02182111c11581a4b6c4506c34</md5>
|
||||
<crc32>cf4f74e9</crc32>
|
||||
<sha1>aed49afc259e203cf158a7c3b99a79b486807d6f</sha1>
|
||||
<format>Unknown</format>
|
||||
<viruscheck>1654573957</viruscheck>
|
||||
</file>
|
||||
<file name="LICENSE.txt" source="original">
|
||||
<mtime>1654573565</mtime>
|
||||
<size>11358</size>
|
||||
<md5>3b83ef96387f14655fc854ddc3c6bd57</md5>
|
||||
<crc32>86e2b4b4</crc32>
|
||||
<sha1>2b8b815229aa8a61e483fb4ba0588b8b6c491890</sha1>
|
||||
<format>Text</format>
|
||||
<viruscheck>1654573957</viruscheck>
|
||||
</file>
|
||||
<file name="gpt4chan_model_archive.torrent" source="metadata">
|
||||
<btih>bf7434b0050c76127cff1296ab4e6b52ea2c4ac0</btih>
|
||||
<mtime>1654573953</mtime>
|
||||
<size>117775</size>
|
||||
<md5>8c63d8759fb4f69a90427f929616d2d9</md5>
|
||||
<crc32>7413b27b</crc32>
|
||||
<sha1>f905188b4b5d9eb4922917bb547822f7ad146a95</sha1>
|
||||
<format>Archive BitTorrent</format>
|
||||
</file>
|
||||
<file name="gpt4chan_model_files.xml" source="original">
|
||||
<format>Metadata</format>
|
||||
<md5>3779f4857a47de3299d79141aa9b7338</md5>
|
||||
<summation>md5</summation>
|
||||
</file>
|
||||
<file name="gpt4chan_model_meta.sqlite" source="original">
|
||||
<mtime>1654573950</mtime>
|
||||
<size>20480</size>
|
||||
<md5>07b3bfccb6319d1ee47c709f24160c05</md5>
|
||||
<crc32>85495edc</crc32>
|
||||
<sha1>4398f142157d5383b6c7aa1074931cf648a289b5</sha1>
|
||||
<format>Metadata</format>
|
||||
</file>
|
||||
<file name="gpt4chan_model_meta.xml" source="original">
|
||||
<mtime>1654572291</mtime>
|
||||
<size>1342</size>
|
||||
<md5>89177f0660e01603ca310e5043fbf952</md5>
|
||||
<crc32>1b79cb16</crc32>
|
||||
<sha1>19c645562751d62b82dd9382e79e043aa98a73b1</sha1>
|
||||
<format>Metadata</format>
|
||||
</file>
|
||||
<file name="pytorch_model.bin" source="original">
|
||||
<mtime>1654566951</mtime>
|
||||
<size>24207819307</size>
|
||||
<md5>833c1dc19b7450e4e559a9917b7d076a</md5>
|
||||
<crc32>f14169d7</crc32>
|
||||
<sha1>1837f8cb55c5c6570fa66049ad4b14781145c8b5</sha1>
|
||||
<format>Unknown</format>
|
||||
</file>
|
||||
</files>
|
Binary file not shown.
|
@ -0,0 +1,18 @@
|
|||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<metadata>
|
||||
<identifier>gpt4chan_model</identifier>
|
||||
<collection>open_source_software</collection>
|
||||
<collection>community</collection>
|
||||
<licenseurl>https://www.apache.org/licenses/LICENSE-2.0</licenseurl>
|
||||
<scanner>Internet Archive Python library 3.0.1</scanner>
|
||||
<mediatype>data</mediatype>
|
||||
<uploader>valentino.giudice96@gmail.com</uploader>
|
||||
<title> GPT-4chan Model</title>
|
||||
<publicdate>2022-06-07 01:56:14</publicdate>
|
||||
<addeddate>2022-06-07 01:56:14</addeddate>
|
||||
<curation>[curator]validator@archive.org[/curator][date]20220607020703[/date][comment]checked for malware[/comment]</curation>
|
||||
<creator> Yannic Kilcher</creator>
|
||||
<description><div><div>GPT-4chan is a language model fine-tuned from <a href="https://huggingface.co/EleutherAI/gpt-j-6B" rel="nofollow">GPT-J 6B</a> on 3.5 years worth of data from 4chan's politically incorrect (/pol/) board, as included in the dataset <span style="border-style:solid;border-color:rgb(229,231,235);"><a href="https://zenodo.org/record/3606810" rel="nofollow">Raiders of the Lost Kek: 3.5 Years of Augmented 4chan Posts from the Politically Incorrect Board</a></span>.</div></div></description>
|
||||
<publisher> Yannic Kilcher</publisher>
|
||||
<language>English</language>
|
||||
</metadata>
|
|
@ -0,0 +1,77 @@
|
|||
#!/usr/bin/env python3
|
||||
|
||||
import json
|
||||
from pathlib import Path
|
||||
from loguru import logger
|
||||
from tabulate import tabulate
|
||||
import lm_eval.tasks
|
||||
|
||||
m1 = 'GPT-J-6B'
|
||||
m2 = 'GPT-4chan'
|
||||
|
||||
log_dir = Path('./eval_logs')
|
||||
all_tasks = set()
|
||||
model_data = {}
|
||||
for fn in log_dir.rglob('log_*.stdout.txt'):
|
||||
try:
|
||||
file_text = fn.read_text()
|
||||
data = json.loads('{' + file_text.split('{', 1)[1].rsplit('}', 1)[0] + '}')
|
||||
model = data['config']['model_args'].split('=')[1]
|
||||
model = m2 if 'fp16' in model else m1
|
||||
if model not in model_data:
|
||||
model_data[model] = {}
|
||||
results = data['results']
|
||||
tasks = list(results.keys())
|
||||
assert len(tasks) == 1, 'Only one task supported'
|
||||
task = tasks[0]
|
||||
if task in model_data[model]:
|
||||
raise ValueError(f'Duplicate task {task}')
|
||||
task_version = data['versions'][task]
|
||||
results = results[task]
|
||||
results_data = {}
|
||||
for result_key in results:
|
||||
if result_key.endswith('_stderr'):
|
||||
continue
|
||||
result_value = results[result_key]
|
||||
results_data[result_key] = {'value': result_value}
|
||||
stderr_key = f'{result_key}_stderr'
|
||||
if stderr_key in results:
|
||||
results_data[result_key]['stderr'] = results[stderr_key]
|
||||
else:
|
||||
logger.warning(f'No stderr for {result_key} in {results}')
|
||||
model_data[model][task] = {'version': task_version, 'results': results_data}
|
||||
all_tasks.add(task)
|
||||
except Exception:
|
||||
logger.exception(f'Failed to parse {fn}')
|
||||
continue
|
||||
|
||||
all_models = list(sorted(model_data.keys()))
|
||||
table_data = []
|
||||
for task in all_tasks:
|
||||
try:
|
||||
higher_is_better = lm_eval.tasks.get_task(task).higher_is_better(None)
|
||||
except Exception:
|
||||
logger.warning(f'Failed to get higher_is_better for {task}')
|
||||
continue
|
||||
if any(task not in model_data[model] for model in all_models):
|
||||
logger.warning(f'No results for {task}')
|
||||
continue
|
||||
results = model_data[m1][task]['results']
|
||||
results2 = model_data[m2][task]['results']
|
||||
for metric in results:
|
||||
result_value = results[metric]['value']
|
||||
stderr_value = results[metric].get('stderr', 0.0)
|
||||
result2_value = results2[metric]['value']
|
||||
stderr2_value = results2[metric].get('stderr', 0.0)
|
||||
significance = (result_value - result2_value) / ((stderr_value + stderr2_value + 1e-8) / 2)
|
||||
if higher_is_better[metric]:
|
||||
significance *= -1
|
||||
if abs(significance) > 1:
|
||||
significant = '+' if significance > 0 else '-'
|
||||
else:
|
||||
significant = ''
|
||||
table_data.append([task, metric, result_value, stderr_value, result2_value, stderr2_value, significant])
|
||||
|
||||
table_str = tabulate(table_data, headers=['Task', 'Metric', m1, 'stderr', m2, 'stderr', 'Significant'], tablefmt='pipe')
|
||||
print(table_str)
|
||||
Path('./results.table.txt').write_text(table_str)
|
|
@ -0,0 +1,59 @@
|
|||
#!/usr/bin/env python3
|
||||
|
||||
import json
|
||||
import bs4
|
||||
from loguru import logger
|
||||
import multiprocessing as mp
|
||||
import tqdm
|
||||
from absl import app, flags
|
||||
|
||||
import warnings
|
||||
warnings.filterwarnings("ignore", category=bs4.MarkupResemblesLocatorWarning, module='bs4')
|
||||
|
||||
DATA_FN = '../tmp/pol_062016-112019_labeled.ndjson'
|
||||
OUT_FN = '../tmp/kek.txt'
|
||||
|
||||
flags.DEFINE_string('data_fn', DATA_FN, 'data file')
|
||||
flags.DEFINE_string('out_fn', OUT_FN, 'output file')
|
||||
|
||||
FLAGS = flags.FLAGS
|
||||
|
||||
# from here: https://gist.github.com/zmwangx/ad0830ba94b1fd98f428
|
||||
def text_with_newlines(elem):
|
||||
text = ''
|
||||
for e in elem.descendants:
|
||||
if isinstance(e, str):
|
||||
# text += e.strip()
|
||||
text += e
|
||||
elif e.name == 'br' or e.name == 'p':
|
||||
text += '\n'
|
||||
return text
|
||||
|
||||
|
||||
def parse_line(line):
|
||||
data = json.loads(line)
|
||||
posts_text = []
|
||||
for post in data.get('posts', []):
|
||||
try:
|
||||
if 'com' in post:
|
||||
soup = bs4.BeautifulSoup(post['com'], 'lxml')
|
||||
post_text = text_with_newlines(soup).strip()
|
||||
else:
|
||||
post_text = ''
|
||||
post_text = f'--- {post["no"]}\n{post_text}'
|
||||
posts_text.append(post_text)
|
||||
except Exception:
|
||||
logger.exception(f'failed to parse post {post}')
|
||||
return '\n'.join(posts_text)
|
||||
|
||||
|
||||
def main(_):
|
||||
with open(FLAGS.out_fn, 'w') as out_f:
|
||||
with open(FLAGS.data_fn) as in_f:
|
||||
with mp.Pool() as pool:
|
||||
for parsed_line in pool.imap(parse_line, tqdm.tqdm(in_f)):
|
||||
out_f.write(parsed_line + '\n-----\n')
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
app.run(main)
|
|
@ -0,0 +1,11 @@
|
|||
clone https://github.com/kingoflolz/mesh-transformer-jax and put this code inside
|
||||
|
||||
`model` is from https://github.com/okbuddyhololive/project-cybertard with slight changes
|
||||
|
||||
(you only need to do the above things if you want to run jax inference. for hugging face, it is not necessary)
|
||||
|
||||
then run `uvicorn --host 0.0.0.0 --port 8080 serve_api:app`
|
||||
|
||||
I use python 3.9.12 and install requirements.txt, then uninstall jax, jaxlib, tensorflow, and tensorflow-cpu
|
||||
|
||||
install `jax==0.2.12 jaxlib==0.1.67 tensorflow==2.5.0 markupsafe==2.0.1 uvicorn fastapi loguru`
|
|
@ -0,0 +1,2 @@
|
|||
from .inference import Inference
|
||||
from .constants import ModelParams, InferConfig
|
|
@ -0,0 +1,45 @@
|
|||
import typing
|
||||
from dataclasses import dataclass
|
||||
|
||||
import optax
|
||||
|
||||
BAD_WORDS = [] # Can't be part of config to avoid printing it
|
||||
|
||||
|
||||
@dataclass
|
||||
class InferConfig:
|
||||
name: str = "Holotard"
|
||||
prompt_length: int = 65536
|
||||
token_length: int = 16
|
||||
|
||||
response_probability: float = 0.02
|
||||
top_p: float = 1.0
|
||||
|
||||
min_temperature: float = 0.6
|
||||
max_temperature: float = 1.2
|
||||
|
||||
max_same_replies: int = 2
|
||||
same_reply_saved_messages: int = 6
|
||||
max_response_retries: int = 3
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModelParams:
|
||||
layers: int = 28
|
||||
d_model: int = 4096
|
||||
n_heads: int = 16
|
||||
n_vocab: int = 50400
|
||||
|
||||
norm: str = "layernorm"
|
||||
pe: str = "rotary"
|
||||
pe_rotary_dims: int = 64
|
||||
|
||||
seq: int = 2048
|
||||
cores_per_replica: int = 8
|
||||
per_replica_batch: int = 1
|
||||
|
||||
# batch size of 2 needs 200gb, 1 needs <16. wtf
|
||||
optimizer: optax.chain = optax.chain(optax.adaptive_grad_clip(0.001), optax.centralize(),
|
||||
optax.scale_by_adam(0.99, 0.999), optax.additive_weight_decay(1e-3),
|
||||
optax.scale(-1e-5), )
|
||||
sampler = None
|
|
@ -0,0 +1,111 @@
|
|||
import random
|
||||
from typing import Any, Optional
|
||||
from loguru import logger
|
||||
import time
|
||||
|
||||
import jax
|
||||
import numpy as np
|
||||
import transformers
|
||||
from jax import numpy as jnp
|
||||
from jax.experimental import maps
|
||||
from mesh_transformer.checkpoint import read_ckpt_lowmem
|
||||
from mesh_transformer.sampling import nucleaus_sample
|
||||
from mesh_transformer.transformer_shard import CausalTransformer
|
||||
|
||||
from .constants import ModelParams, InferConfig
|
||||
|
||||
|
||||
def default(value: Any, fallback: Any) -> Any:
|
||||
# luke prefers making a function that chooses between `value` and `feedback` so i am gonna keep it
|
||||
if value is None:
|
||||
return fallback
|
||||
|
||||
return value
|
||||
|
||||
|
||||
_cores_per_replica = ModelParams.cores_per_replica
|
||||
_mesh_shape = (jax.device_count() // _cores_per_replica, _cores_per_replica)
|
||||
_devices = np.array(jax.devices()).reshape(_mesh_shape)
|
||||
|
||||
#maps.thread_resources.env = maps.ResourceEnv(maps.Mesh(_devices, ("dp", "mp")), ())
|
||||
maps.thread_resources.env = maps.ResourceEnv(maps.Mesh(_devices, ("dp", "mp")))
|
||||
|
||||
|
||||
class Inference:
|
||||
_NP_ONE = np.ones((1,))
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
path: Optional[str] = None,
|
||||
parameters: Optional[ModelParams] = None,
|
||||
config: Optional[InferConfig] = None,
|
||||
):
|
||||
path = "checkpoint_slim/" if path is None else path
|
||||
|
||||
self.params = ModelParams() if parameters is None else parameters
|
||||
self.params.sampler = nucleaus_sample
|
||||
self.config = InferConfig() if config is None else config
|
||||
|
||||
self.model = CausalTransformer(self.params.__dict__)
|
||||
self.tokenizer = transformers.GPT2TokenizerFast.from_pretrained("gpt2")
|
||||
self.model.state = read_ckpt_lowmem(
|
||||
self.model.state, path, self.params.cores_per_replica, load_opt=False
|
||||
)
|
||||
|
||||
def generate_tokens(
|
||||
self,
|
||||
prompt: np.ndarray,
|
||||
length: Optional[int] = None,
|
||||
top_p: Optional[float] = None,
|
||||
temperature: Optional[float] = None,
|
||||
) -> np.ndarray:
|
||||
length = default(length, self.config.token_length)
|
||||
top_p = default(top_p, self.config.top_p)
|
||||
new_temp = random.random() * (self.config.max_temperature - self.config.min_temperature)
|
||||
new_temp += self.config.min_temperature
|
||||
temperature = default(temperature, new_temp)
|
||||
#prompt = prompt[:, -2048:]
|
||||
#prompt = prompt[:, -length:]
|
||||
|
||||
start_time = time.time()
|
||||
source = jnp.array(
|
||||
np.pad(
|
||||
prompt,
|
||||
(
|
||||
(0, 0),
|
||||
(self.params.seq - prompt.shape[1], 0),
|
||||
),
|
||||
)
|
||||
)
|
||||
logger.info(f"creating source took {time.time() - start_time}")
|
||||
sampler_options = {
|
||||
"top_p": self._NP_ONE * top_p,
|
||||
"temp": self._NP_ONE * temperature,
|
||||
}
|
||||
|
||||
start_time = time.time()
|
||||
#with jax.experimental.maps.mesh(_devices, ("dp", "mp")):
|
||||
logger.info(f"creating mesh took {time.time() - start_time}")
|
||||
start_time = time.time()
|
||||
out = self.model.generate(
|
||||
source, self._NP_ONE * prompt.shape[1], length, sampler_options
|
||||
)
|
||||
logger.info(f"generate took {time.time() - start_time}")
|
||||
|
||||
#import IPython; IPython.embed()
|
||||
return out[1][0][0, :, 0]
|
||||
|
||||
def generate(
|
||||
self,
|
||||
prompt: str,
|
||||
length: Optional[int] = None,
|
||||
top_p: Optional[float] = None,
|
||||
temperature: Optional[float] = None,
|
||||
) -> str:
|
||||
inp_tokens = self.tokenizer([prompt], verbose=False, return_tensors="np")
|
||||
inp_tokens = inp_tokens["input_ids"][0]
|
||||
out_tokens = self.generate_tokens(
|
||||
inp_tokens.reshape(1, -1), length, top_p, temperature
|
||||
)
|
||||
|
||||
return self.tokenizer.decode(out_tokens)
|
|
@ -0,0 +1,52 @@
|
|||
import argparse
|
||||
import json
|
||||
import time
|
||||
|
||||
import jax
|
||||
import numpy as np
|
||||
import optax
|
||||
|
||||
from mesh_transformer import util
|
||||
from mesh_transformer.checkpoint import read_ckpt, write_ckpt, read_ckpt_lowmem
|
||||
from mesh_transformer.transformer_shard import CausalTransformer
|
||||
from smart_open import open
|
||||
|
||||
from mesh_transformer.util import clip_by_global_norm, to_bf16, to_f16
|
||||
from model.constants import ModelParams
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
params = ModelParams().__dict__
|
||||
convert_fn = to_bf16
|
||||
|
||||
cores_per_replica = params["cores_per_replica"]
|
||||
|
||||
assert cores_per_replica <= 8
|
||||
|
||||
start = time.time()
|
||||
print(f"jax devices: {jax.device_count()}")
|
||||
print(f"jax runtime initialized in {time.time() - start:.06}s")
|
||||
|
||||
mesh_shape = (jax.device_count() // cores_per_replica, cores_per_replica)
|
||||
devices = np.array(jax.devices()).reshape(mesh_shape)
|
||||
|
||||
with jax.experimental.maps.mesh(devices, ("dp", "mp")):
|
||||
network = CausalTransformer(params)
|
||||
|
||||
start = time.time()
|
||||
network.state = read_ckpt(
|
||||
network.state, f"checkpoint/", devices.shape[1], load_opt=False
|
||||
)
|
||||
print(f"network loaded in {time.time() - start:.06}s")
|
||||
|
||||
start = time.time()
|
||||
del network.state["opt_state"]
|
||||
|
||||
network.state["params"] = convert_fn(network.state["params"])
|
||||
print(f"network converted in {time.time() - start:.06}s")
|
||||
|
||||
suffix = "_slim"
|
||||
|
||||
for i in range(cores_per_replica):
|
||||
write_ckpt(network.state, f"checkpoint_slim/", i)
|
||||
print(f"written shard {i}")
|
|
@ -0,0 +1,199 @@
|
|||
#!/usr/bin/env python3
|
||||
|
||||
from typing import Optional
|
||||
import threading
|
||||
import queue
|
||||
import time
|
||||
from loguru import logger
|
||||
from pathlib import Path
|
||||
import contextlib
|
||||
|
||||
import pydantic
|
||||
from fastapi import FastAPI
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
|
||||
app = FastAPI()
|
||||
|
||||
origins = ["*"]
|
||||
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=origins,
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
|
||||
class Settings(pydantic.BaseSettings):
|
||||
queue_size: int = 1024
|
||||
log_file: str = "logs/serve_api.log"
|
||||
api_keys_file: str = 'valid_api_keys.txt'
|
||||
hf_model: str = ''
|
||||
hf_cuda: bool = False
|
||||
pre_prompt_length: int = 512
|
||||
|
||||
|
||||
settings = Settings()
|
||||
|
||||
def _check_api_key(key):
|
||||
key = key.strip()
|
||||
for line in Path(settings.api_keys_file).open():
|
||||
if not line:
|
||||
continue
|
||||
valid_key = line.split()[0]
|
||||
if key == valid_key:
|
||||
break
|
||||
else:
|
||||
return False
|
||||
return True
|
||||
|
||||
request_queue = queue.Queue(maxsize=settings.queue_size)
|
||||
|
||||
@contextlib.contextmanager
|
||||
def jax_generation():
|
||||
from model import inference
|
||||
import jax
|
||||
model = inference.Inference(path="../model_slim/step_88001/")
|
||||
|
||||
def _generate(request):
|
||||
response = model.generate(
|
||||
prompt=request.prompt,
|
||||
length=request.length,
|
||||
top_p=request.top_p,
|
||||
temperature=request.temperature,
|
||||
)
|
||||
return response
|
||||
with jax.experimental.maps.mesh(inference._devices, ("dp", "mp")):
|
||||
yield _generate
|
||||
|
||||
@contextlib.contextmanager
|
||||
def hf_generation():
|
||||
from transformers import GPTJForCausalLM, AutoTokenizer
|
||||
import torch
|
||||
|
||||
if settings.hf_cuda:
|
||||
model = GPTJForCausalLM.from_pretrained(
|
||||
settings.hf_model, revision="float16", torch_dtype=torch.float16, low_cpu_mem_usage=True
|
||||
)
|
||||
model.cuda()
|
||||
else:
|
||||
model = GPTJForCausalLM.from_pretrained( settings.hf_model, torch_dtype=torch.float32)
|
||||
model.eval()
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-j-6B")
|
||||
|
||||
def _generate(request: CompleteRequest):
|
||||
input_ids = tokenizer(request.prompt, return_tensors="pt").input_ids
|
||||
|
||||
max_prompt_length = 2048 - request.length
|
||||
input_ids = input_ids[:, -max_prompt_length:]
|
||||
|
||||
if request.pre_prompt:
|
||||
pp_input_ids = tokenizer(request.pre_prompt, return_tensors="pt").input_ids
|
||||
pp_input_ids = pp_input_ids[:, :settings.pre_prompt_length]
|
||||
input_ids = input_ids[:, -(max_prompt_length-len(pp_input_ids)):]
|
||||
full_prompt = tokenizer.batch_decode(pp_input_ids)[0] + tokenizer.batch_decode(input_ids)[0]
|
||||
input_ids = tokenizer(full_prompt, return_tensors="pt").input_ids
|
||||
input_ids = input_ids[:, -max_prompt_length:]
|
||||
|
||||
|
||||
if settings.hf_cuda:
|
||||
input_ids = input_ids.cuda()
|
||||
|
||||
with torch.no_grad():
|
||||
gen_tokens = model.generate(
|
||||
input_ids,
|
||||
do_sample=True,
|
||||
temperature=request.temperature,
|
||||
top_p=request.top_p,
|
||||
typical_p=request.typical_p,
|
||||
max_new_tokens=request.length,
|
||||
).detach().cpu()
|
||||
gen_text = tokenizer.batch_decode(gen_tokens)[0]
|
||||
prompt_decoded = tokenizer.batch_decode(input_ids.detach().cpu())[0]
|
||||
if not gen_text.startswith(prompt_decoded):
|
||||
raise Exception(f"Generated text does not start with prompt: {gen_text}\n(prompt was {prompt_decoded})")
|
||||
gen_text = gen_text[len(prompt_decoded):]
|
||||
return gen_text
|
||||
yield _generate
|
||||
|
||||
def worker():
|
||||
if settings.hf_model:
|
||||
generation = hf_generation
|
||||
else:
|
||||
generation = jax_generation
|
||||
with generation() as generate_fn:
|
||||
with open(settings.log_file, "a") as logf:
|
||||
while True:
|
||||
response_queue = None
|
||||
try:
|
||||
start_time = time.time()
|
||||
(request, response_queue) = request_queue.get()
|
||||
logger.info(f"getting request took {time.time() - start_time}")
|
||||
start_time = time.time()
|
||||
response = generate_fn(request)
|
||||
logger.info(f"generate took {time.time() - start_time}, response length: {len(response)}")
|
||||
start_time = time.time()
|
||||
|
||||
logf.write(f"##### {request.api_key} ##### {time.time()} #####\n")
|
||||
logf.write(f"{request.pre_prompt}\n")
|
||||
logf.write("###\n")
|
||||
logf.write(f"{request.prompt}\n")
|
||||
logf.write("#####\n")
|
||||
logf.write(f"{response}\n\n")
|
||||
logf.flush()
|
||||
|
||||
logger.info(f"writing log took {time.time() - start_time}")
|
||||
start_time = time.time()
|
||||
response_queue.put(response)
|
||||
logger.info(f"putting response took {time.time() - start_time}")
|
||||
except KeyboardInterrupt:
|
||||
logger.info(f"Got KeyboardInterrupt... quitting!")
|
||||
raise
|
||||
except Exception:
|
||||
logger.exception(f"Got exception, will continue")
|
||||
if response_queue is not None:
|
||||
response_queue.put("")
|
||||
|
||||
|
||||
|
||||
@app.get("/")
|
||||
async def main():
|
||||
return {"response": "Hello, world!"}
|
||||
|
||||
class CompleteRequest(pydantic.BaseModel):
|
||||
prompt: pydantic.constr(min_length=0, max_length=2**14)
|
||||
pre_prompt: pydantic.constr(min_length=0, max_length=2**14) = ''
|
||||
api_key: pydantic.constr(min_length=1, max_length=128) = "x"*9
|
||||
length: pydantic.conint(ge=1, le=1024) = 128
|
||||
top_p: pydantic.confloat(ge=0.0, le=1.0) = 1.0
|
||||
temperature: pydantic.confloat(ge=0.0) = 1.0
|
||||
typical_p: pydantic.confloat(ge=0.0, le=1.0) = 1.0
|
||||
|
||||
def _enqueue(request: CompleteRequest):
|
||||
response_queue = queue.Queue()
|
||||
request_queue.put((request, response_queue))
|
||||
response = response_queue.get()
|
||||
return response
|
||||
|
||||
|
||||
@app.on_event("startup")
|
||||
def startup():
|
||||
threading.Thread(
|
||||
target=worker,
|
||||
daemon=True,
|
||||
).start()
|
||||
_enqueue(CompleteRequest(prompt="hello"))
|
||||
|
||||
@app.post("/complete")
|
||||
def complete(request: CompleteRequest):
|
||||
logger.info(f"Received request from key {request.api_key}. Queue size is {request_queue.qsize()}")
|
||||
if request_queue.full():
|
||||
logger.warning("Request queue full.")
|
||||
raise ValueError("Request queue full.")
|
||||
if not _check_api_key(request.api_key):
|
||||
logger.warning(f"api key not valid: {request.api_key}, discarding...")
|
||||
raise ValueError("Invalid API key")
|
||||
response = _enqueue(request)
|
||||
return {"response": response}
|
|
@ -0,0 +1,65 @@
|
|||
#!/usr/bin/env python3
|
||||
|
||||
from absl import flags, app
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from pathlib import Path
|
||||
import shutil
|
||||
import functools
|
||||
|
||||
import tokenizers
|
||||
import tensorflow as tf
|
||||
import tqdm
|
||||
|
||||
flags.DEFINE_string('txt_fn', '../tmp/kek.txt', 'input txt')
|
||||
flags.DEFINE_string('out_dir', '../tmp/tfrecords/', 'output directory (will be cleared)')
|
||||
flags.DEFINE_integer('chunk_size', 2**24, 'how many tokens go into one tfrecords file')
|
||||
flags.DEFINE_integer('read_buffer_size', 2**10, 'input file read buffer size')
|
||||
|
||||
FLAGS = flags.FLAGS
|
||||
|
||||
|
||||
@functools.lru_cache(maxsize=1)
|
||||
def get_tokenizer():
|
||||
return tokenizers.Tokenizer.from_pretrained('gpt2')
|
||||
|
||||
|
||||
def make_record_file(record_ids, out_dir, file_no):
|
||||
out_fn = str(out_dir / f'tokens-{file_no:05d}.tfrecord')
|
||||
with tf.io.TFRecordWriter(out_fn) as writer:
|
||||
feature = {'text': tf.train.Feature(int64_list=tf.train.Int64List(value=record_ids))}
|
||||
tf_example = tf.train.Example(features=tf.train.Features(feature=feature))
|
||||
writer.write(tf_example.SerializeToString())
|
||||
|
||||
|
||||
def read_in_blocks(f):
|
||||
while True:
|
||||
block = f.read(FLAGS.read_buffer_size)
|
||||
if not block:
|
||||
break
|
||||
yield block
|
||||
|
||||
def main(_):
|
||||
out_dir = Path(FLAGS.out_dir)
|
||||
if out_dir.exists():
|
||||
logger.warning(f'clearing {out_dir}')
|
||||
shutil.rmtree(out_dir)
|
||||
out_dir.mkdir(exist_ok=True)
|
||||
tokenizer = get_tokenizer()
|
||||
with open(FLAGS.txt_fn) as in_f:
|
||||
current_ids = []
|
||||
out_file_no = 0
|
||||
for block in tqdm.tqdm(read_in_blocks(in_f)):
|
||||
current_ids.extend(tokenizer.encode(block).ids)
|
||||
while len(current_ids) >= FLAGS.chunk_size:
|
||||
record_ids, current_ids = current_ids[:FLAGS.chunk_size], current_ids[FLAGS.chunk_size:]
|
||||
make_record_file(record_ids, out_dir, out_file_no)
|
||||
out_file_no += 1
|
||||
|
||||
if current_ids:
|
||||
make_record_file(current_ids, out_dir, out_file_no)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
app.run(main)
|
Loading…
Reference in New Issue